import argparse
import sys
import os
from math import ceil

import json
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from transformers import AutoImageProcessor, LlavaNextProcessor

from tqdm import tqdm

from model import VQVAE
from scheduler import CycleScheduler
import distributed as dist
from time import sleep

##########################################################################################################
##########################################################################################################
##########################################################################################################
def log_metrics(mse_list, cos_list, lr_list, epoch_list, i_list, iva_list, mse, cos, lr, epoch, i, iva):
    """Append one new reading to each of the six metric histories (in place).

    The first six arguments are the history lists; the last six are the values
    appended to them, in the same order: mse, cos, lr, epoch, i, iva.
    """
    histories = (mse_list, cos_list, lr_list, epoch_list, i_list, iva_list)
    readings = (mse, cos, lr, epoch, i, iva)
    for history, reading in zip(histories, readings):
        history.append(reading)


def save_metrics(metric_dict, values_save_dir, plots_save_dir, save_name_preffix, block_size=6):
    """Save every metric history as ``.npy`` and plot all of them in one stacked figure.

    Args:
        metric_dict: mapping ``metric name -> sequence of values``. An empty
            mapping is a no-op.
        values_save_dir: parent directory; arrays go to ``<values_save_dir>/values``.
        plots_save_dir: parent directory; the figure goes to ``<plots_save_dir>/plots``.
        save_name_preffix: prefix of every file name; each array is stored as
            ``<prefix>_<metric name>.npy``.
        block_size: height (in inches) of one subplot row; the width is four
            times that.

    The single combined figure is written as ``<prefix>_<last metric name>.png``.
    That naming is inherited from the original implementation and kept so
    existing output paths do not change.
    """
    if not metric_dict:
        # Nothing to store, and plt.subplots() rejects a grid with zero rows.
        return

    # The output directories do not depend on the metric, so create them once.
    values_dir = os.path.join(values_save_dir, 'values')
    plot_dir = os.path.join(plots_save_dir, 'plots')
    os.makedirs(values_dir, mode=0o777, exist_ok=True)
    os.makedirs(plot_dir, mode=0o777, exist_ok=True)

    ncols = 1
    nrows = len(metric_dict)
    # squeeze=False keeps `axes` a 2-D array even when there is a single metric;
    # otherwise a bare Axes object is returned and indexing it fails.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False,
                             figsize=(ncols * block_size * 4, nrows * block_size))
    try:
        name = save_name_preffix
        for row, (metric_name, metric_list) in enumerate(metric_dict.items()):
            name = save_name_preffix + f'_{metric_name}'
            np.save(os.path.join(values_dir, name + '.npy'), np.asarray(metric_list))

            ax = axes[row, 0]
            ax.plot(metric_list, label=metric_name)
            ax.legend()

        fig.savefig(os.path.join(plot_dir, name + '.png'))
    finally:
        # Always release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


##########################################################################################################
##########################################################################################################
##########################################################################################################
def calculate_encoding_mask(input, encoding_inds):
    """Build a 0/1 mask over the last two axes of ``input``.

    A flat zero vector of ``input.shape[1] * input.shape[2]`` elements is
    created, the positions selected by ``encoding_inds`` are set to 1, and the
    vector is reshaped to ``(input.shape[1], input.shape[2])``.
    """
    height, width = input.shape[1], input.shape[2]
    flat_mask = torch.zeros(height * width)
    flat_mask[encoding_inds] = 1
    return flat_mask.view(height, width)


def draw_in_out_cos(images,
                    reconstructed_images,
                    in_out_cos_list,
                    inputs,
                    encoding_inds_list,
                    examples=10, block_size=6, save_dir=None, save_name=None):
    """Plot a five-column diagnostic row for each example.

    Columns, left to right: the image, the min-max normalised reconstructed
    image, the encoding mask (titled 'gumbel mask'), the input/reconstruction
    cosine-similarity map, and the same map flattened into a curve.

    Args:
        images: indexable collection of images accepted by ``imshow``.
        reconstructed_images: per-example tensors, permuted with ``(1, 2, 0)``
            before display.
        in_out_cos_list: per-example cosine-similarity tensors.
        inputs: per-example tensors; only ``shape[1]`` and ``shape[2]`` are
            used, to size the mask.
        encoding_inds_list: per-example indices marked in the mask.
        examples: maximum number of rows; clipped to ``len(images)``.
        block_size: edge length (inches) of one subplot cell.
        save_dir, save_name: when both are truthy the figure is written to
            ``save_dir/save_name`` and closed; otherwise it is left open.
    """
    examples = min(examples, len(images))
    if examples <= 0:
        # Nothing to draw; plt.subplots() rejects a grid with zero rows.
        return

    ncols = 5
    nrows = examples
    # squeeze=False keeps `axes` two-dimensional for a single example, so the
    # axes[i, j] indexing below works for any number of rows.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False,
                             figsize=(ncols * block_size, nrows * block_size))

    for i in range(examples):
        row = axes[i]

        # Column 0/1: image and reconstruction rescaled to [0, 1].
        row[0].imshow(images[i])
        reconstructed_image = reconstructed_images[i].cpu().permute(1, 2, 0)
        row[1].imshow((reconstructed_image - reconstructed_image.min())
                      / (reconstructed_image.max() - reconstructed_image.min()))

        # Column 2: which positions were selected for encoding.
        input, encoding_inds = inputs[i].cpu(), encoding_inds_list[i].cpu()
        row[2].imshow(calculate_encoding_mask(input, encoding_inds))

        # Column 3/4: cosine similarity as a map and as a flat curve.
        in_out_cos = in_out_cos_list[i].cpu()
        row[3].imshow(in_out_cos)
        row[4].plot(in_out_cos.view(-1), alpha=0.5, label='in_out_cos')
        row[4].legend()

        if i == 0:
            row[0].set_title('image')
            row[1].set_title('reconstructed_image')
            row[2].set_title('gumbel mask')
            row[3].set_title('cosine similarity between\ninput and reconstruction')
            row[4].set_title('cosine similarity between\ninput and reconstruction\n(flatten version)')

        # Image-like panels need no axes; the curve in column 4 keeps them.
        for ax in row[:4]:
            ax.set_axis_off()

    if save_dir and save_name:
        fig.savefig(os.path.join(save_dir, save_name))
        plt.close(fig)


##########################################################################################################
##########################################################################################################
##########################################################################################################
def draw_input_norm(inputs, outputs, examples=10, block_size=6, save_dir=None, save_name=None):
    """Plot, per example, the flattened ``dim=0`` norms of an input and its output.

    Args:
        inputs: indexable collection of tensors; ``norm(dim=0)`` is taken of each.
        outputs: indexable collection of tensors drawn on the same axes.
        examples: maximum number of rows; clipped to ``len(inputs)``.
        block_size: height (inches) of one row; the width is four times that.
        save_dir, save_name: when both are truthy the figure is written to
            ``save_dir/save_name`` and closed; otherwise it is left open.
    """
    ncols = 1
    examples = min(examples, len(inputs))
    if examples <= 0:
        # Nothing to draw; plt.subplots() rejects a grid with zero rows.
        return
    nrows = examples

    # squeeze=False: with a single example plt.subplots would otherwise return
    # a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False,
                             figsize=(ncols * block_size * 4, nrows * block_size))

    for i in range(examples):
        ax = axes[i, 0]
        input_norm = inputs[i].cpu().norm(dim=0).view(-1)
        output_norm = outputs[i].cpu().norm(dim=0).view(-1)

        ax.plot(input_norm, alpha=0.5, label='input')
        ax.plot(output_norm, alpha=0.5, label='output')
        ax.legend()

        if i == 0:
            ax.set_title('norms of input and reconstructed features')

    if save_dir and save_name:
        fig.savefig(os.path.join(save_dir, save_name))
        plt.close(fig)


##########################################################################################################
##########################################################################################################
##########################################################################################################
def val(model, eval_dataset, batch_size, device):
    """Run the model on a fixed handful of evaluation samples.

    The model is switched to eval mode. Samples are taken at a hard-coded list
    of dataset indices, stopping at the first index that is out of range.
    Features are divided in place by their norm along ``dim=1`` before the
    forward pass. ``batch_size`` is accepted but not used.

    Returns:
        ``(images, interpolated_images, out, input, check_dict, loss)`` where
        ``images`` is ``interpolated_images`` permuted with ``(0, 2, 3, 1)``,
        moved to CPU and min-max normalised over the whole batch, and the
        remaining items come straight from ``model(...)`` and
        ``model.loss_function(...)``.
    """
    model.eval()

    # Fixed probe positions inside the evaluation set.
    probe_indices = (0, 1, 100, 200,
                     5100, 5200, 5300,
                     6100, 6200, 6300)

    image_rows, feature_rows = [], []
    for idx in probe_indices:
        if idx >= len(eval_dataset):
            break

        # The first item (raw image) is not needed here.
        _, inter_im, fe = eval_dataset[idx]
        image_rows.append(inter_im.to(device)[None])
        feature_rows.append(fe.to(device)[None])

    interpolated_images = torch.cat(image_rows, dim=0)
    sample = torch.cat(feature_rows, dim=0)
    sample /= sample.norm(dim=1, keepdim=True)

    with torch.no_grad():
        out, out_sample, model_input, check_dict = model(sample, interpolated_images)
        loss = model.loss_function(interpolated_images, sample, out, out_sample,
                                   check_dict['input_vector_amount'])

    # Channels-last copy on CPU, rescaled to [0, 1] for plotting.
    channels_last = interpolated_images.permute(0, 2, 3, 1).cpu()
    lowest, highest = channels_last.min(), channels_last.max()
    images = (channels_last - lowest) / (highest - lowest)
    return images, interpolated_images, out, model_input, check_dict, loss


def _named_metric_lists(metric_tuple):
    """Map the six positional metric histories to the names used in saved files."""
    mse_list, cos_list, lr_list, epoch_list, i_list, iva_list = metric_tuple
    return {
        'mse': mse_list,
        'cos': cos_list,
        'lr': lr_list,
        'epoch': epoch_list,
        'i': i_list,
        'iva': iva_list,
    }


def train(epoch, stage, loader, model, optimizer, scheduler, device, eval_dataset, args, train_tuple, val_tuple, save_delta, batch_size):
    """Train for one pass over ``loader``; periodically log metrics and validation plots.

    Args:
        epoch: zero-based epoch index (displayed and used in file names as ``epoch + 1``).
        stage: index into ``args.input_vector_amounts`` / ``args.epoch``.
        loader: iterable yielding ``(im, fe)`` batches.
        model: callable as ``model(fe, im)`` and exposing ``loss_function``.
        optimizer: optimizer stepped once per batch.
        scheduler: optional; when not ``None`` its ``step()`` is called before
            ``optimizer.step()``.
        device: device the batches are moved to.
        eval_dataset: dataset handed to ``val``.
        args: namespace providing ``sample_path``, ``input_vector_amounts``,
            ``epoch``, ``encoder_layers`` and ``n_crops``.
        train_tuple, val_tuple: six lists each (mse, cos, lr, epoch, i, iva)
            that are appended to in place.
        save_delta: logging/validation period, in iterations.
        batch_size: only used to build the output directory name.

    All logging, validation and plotting happens on the primary process only.
    """
    is_primary = dist.is_primary()

    # The output locations depend only on the arguments, so resolve them once
    # instead of on every iteration.
    sample_save_dir = metrics_dir = None
    if is_primary:
        loader = tqdm(loader)
        sample_save_dir = make_save_path(args.sample_path, stage,
                                         args.input_vector_amounts[stage], args.epoch[stage],
                                         args.encoder_layers, batch_size, args.n_crops)
        metrics_dir = os.path.join(sample_save_dir, 'metrics')

    # These dicts reference the same list objects that log_metrics appends to.
    train_metrics = _named_metric_lists(train_tuple)
    val_metrics = _named_metric_lists(val_tuple)

    for i, (im, fe) in enumerate(loader):
        model.zero_grad()
        fe = fe.to(device)
        fe /= fe.norm(dim=1, keepdim=True)
        im = im.to(device)

        out, out_fe, _, check_dict = model(fe, im)
        loss = model.loss_function(im, fe, out, out_fe, check_dict['input_vector_amount'])
        loss['loss'].backward()

        # The scheduler is stepped before the optimizer, as in the original code.
        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        if not is_primary:
            continue

        lr = optimizer.param_groups[0]["lr"]
        loader.set_description(
            (
                f"epoch: {epoch + 1}; "
                f"msei: {loss['msei'].item():.5f}; "
                f"msef: {loss['msef'].item():.5f}; "
                f"cos: {loss['cos'].item():.5f}; "
                f"input_vector_amount: {loss['iva'].item():.5f}; "
                f"lr: {lr:.7f}"
            )
        )

        if i % save_delta != 0:
            continue

        # Creating the metrics directory also creates sample_save_dir, its parent.
        os.makedirs(metrics_dir, mode=0o777, exist_ok=True)

        # --- training metrics -------------------------------------------------
        log_metrics(*train_tuple,
                    loss['msei'].item(), loss['cos'].item(), lr, epoch, i,
                    check_dict['input_vector_amount'].item())
        save_metrics(train_metrics,
                     values_save_dir=metrics_dir,
                     plots_save_dir=metrics_dir,
                     save_name_preffix='train')

        # --- validation metrics and plots ------------------------------------
        model.eval()

        val_images, interpolated_images, val_out, val_input, val_check_dict, val_loss = val(
            model, eval_dataset, fe.shape[0], device)
        log_metrics(*val_tuple,
                    val_loss['msei'].item(), val_loss['cos'].item(), lr, epoch, i,
                    val_check_dict['input_vector_amount'].item())
        save_metrics(val_metrics,
                     values_save_dir=metrics_dir,
                     plots_save_dir=metrics_dir,
                     save_name_preffix='val')

        step_tag = f'{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}'
        draw_in_out_cos(val_images, val_check_dict['transposed_input'], val_check_dict['in_out_cos'],
                        val_input, val_check_dict['encoding_inds'],
                        save_dir=sample_save_dir, save_name=f'{step_tag}_in-out-cos.png')
        draw_input_norm(interpolated_images, val_out,
                        save_dir=sample_save_dir, save_name=f'{step_tag}_input-reconstruction-norm.png')

        # Validation switched the model to eval mode; resume training mode.
        model.train()

###########################################################################################################
###########################################################################################################
###########################################################################################################
###########################################################################################################
###########################################################################################################

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

# Per-channel mean and standard deviation handed to T.Normalize in build_transform
# (one value per RGB channel; named after ImageNet).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Return the torchvision pipeline applied to every image tile.

    Steps: convert to RGB when needed, bicubic resize to
    ``input_size x input_size``, convert to tensor, normalise with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # Leave RGB images untouched; convert any other mode.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    target_size = (input_size, input_size)
    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose ratio is nearest to ``aspect_ratio``.

    Candidates are scanned in the given order. A strictly closer candidate
    always wins. On an exact tie the later candidate replaces the current one
    only if the image area ``width * height`` exceeds half of the pixel area
    that candidate's grid would cover at ``image_size`` per tile. With no
    candidates, ``(1, 1)`` is returned.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and image_area > half_tile_area * cols * rows:
            chosen = candidate
    return chosen


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize ``image`` to the best-fitting tile grid and cut it into square tiles.

    Every ``(cols, rows)`` grid with ``min_num <= cols * rows <= max_num`` is a
    candidate; the one chosen by ``find_closest_aspect_ratio`` defines the
    resize target. Tiles are returned row by row, each ``image_size`` on a
    side. When ``use_thumbnail`` is set and more than one tile was produced, a
    whole-image thumbnail of ``image_size x image_size`` is appended.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate the admissible grids, then order them by tile count.
    candidate_grids = set(
        (cols, rows)
        for budget in range(min_num, max_num + 1)
        for cols in range(1, budget + 1)
        for rows in range(1, budget + 1)
        if min_num <= cols * rows <= max_num)
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    best_grid = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    target_width = image_size * best_grid[0]
    target_height = image_size * best_grid[1]
    blocks = best_grid[0] * best_grid[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for tile_index in range(blocks):
        # Row-major position of this tile inside the grid.
        tile_row, tile_col = divmod(tile_index, tiles_per_row)
        crop_box = (
            tile_col * image_size,
            tile_row * image_size,
            (tile_col + 1) * image_size,
            (tile_row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(crop_box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images

def load_image(image_file, input_size=448, max_num=12, use_thumbnail=True):
    """Open an image file and return its tiles as one stacked tensor.

    Args:
        image_file: anything accepted by ``PIL.Image.open``.
        input_size: side length of every tile.
        max_num: upper bound on the number of tiles (thumbnail excluded).
        use_thumbnail: forwarded to the tiling step.

    Returns:
        The tensor produced by ``load_from_pil`` for the RGB-converted image.
    """
    # The context manager closes the underlying file deterministically;
    # convert() returns an independent, fully loaded copy.
    with Image.open(image_file) as opened:
        image = opened.convert('RGB')
    # Same tiling/normalisation path as for in-memory images.
    return load_from_pil(image, input_size=input_size, max_num=max_num, use_thumbnail=use_thumbnail)

def load_from_pil(image, input_size=448, max_num=12, use_thumbnail=True):
    """Tile an already opened image and stack the transformed tiles.

    The image is split by ``dynamic_preprocess``; every tile goes through the
    pipeline from ``build_transform`` and the results are stacked along a new
    leading axis.
    """
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=use_thumbnail, max_num=max_num)
    tile_tensors = list(map(to_tensor, tiles))
    return torch.stack(tile_tensors)


class FeaturesDataset(Dataset):
    """Dataset pairing stored feature tensors with the image tile they came from.

    The JSON map file holds ``{feature_path: [image_path, tile_number]}``
    entries. Every item loads the feature tensor, opens the image, runs
    ``image_processor`` on it and picks tile ``tile_number`` from the result.

    Args:
        features_json_path: path to the JSON map file.
        mode: ``'train'`` returns ``(tile, feature)``; any other value returns
            ``(raw image array, tile, feature)``.
        side_size: if given, tiles whose last dimension differs from it are
            bilinearly resized to ``side_size x side_size``; ``None`` leaves
            tiles untouched.
        image_processor: callable mapping a PIL image to an indexable
            collection of tile tensors.
    """

    def __init__(self, features_json_path, mode='train', side_size=None, image_processor=None):
        with open(features_json_path, 'r') as map_file:
            self.map = list(json.load(map_file).items())
        self.mode = mode

        self.side_size = side_size
        self.image_processor = image_processor

    def __len__(self):
        return len(self.map)

    def _match_side_size(self, processed_image):
        """Return the tile resized to ``side_size``, or unchanged if no resize is needed."""
        # side_size=None (the constructor default) means "keep the native size";
        # passing None on to F.interpolate would raise.
        if self.side_size is None or processed_image.shape[-1] == self.side_size:
            return processed_image
        return F.interpolate(processed_image[None],
                             size=(self.side_size, self.side_size),
                             mode='bilinear', align_corners=False)[0]

    def __getitem__(self, idx):
        feature_path, (image_path, number) = self.map[idx]

        feature = torch.load(feature_path, map_location='cpu').to(torch.float32)

        # Close the file handle right away; convert() yields a loaded copy.
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
        processed_image = self.image_processor(image)[number]
        interpolated_image = self._match_side_size(processed_image)

        # Move the last feature axis to the front
        # (original note: e.g. tensor[768 x 14 x 14] for a tensor[3 x 224 x 224] tile).
        feature = feature.permute(2, 0, 1)
        if self.mode == 'train':
            return interpolated_image, feature
        # Non-train modes also return the full RGB image as an array
        # (original note: e.g. array(427, 640, 3)).
        return np.asarray(image), interpolated_image, feature


def calculat_num_parameters(model):
    """Print the element count of every named parameter and the grand total.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.

    Returns:
        The total number of parameter elements as an ``int``.
    """
    total = 0
    for name, param in model.named_parameters():
        # numel() is an exact int even for 0-dim parameters, whereas
        # np.prod(()) is the float 1.0 and would turn the total into a float.
        count = param.numel()
        print(f'{name}: {count}')
        total += count
    print(f'NUM PARAMETERS: {total}')
    return total


def make_save_path(save_path, stage, iva_factor, epoches, encoder_layers, batch_size, n_crops):
    """Return ``save_path`` joined with a run folder name encoding the settings.

    The folder name lists stage, iva factor, epoch count, encoder layers,
    batch size and crop count, joined by underscores.
    """
    run_tags = (
        f'stage-{stage}',
        f'iva-{iva_factor}',
        f'epoches-{epoches}',
        f'encoder-layers-{encoder_layers}',
        f'batch_size-{batch_size}',
        f'n_crops-{n_crops}',
    )
    return os.path.join(save_path, '_'.join(run_tags))


def main(args):
    """Train the VQVAE stage by stage, saving a checkpoint after every epoch.

    A stage is skipped when its final checkpoint already exists. Every stage
    after the first starts from the final checkpoint of the previous stage,
    polling every 10 seconds until that file appears.
    """
    device = args.device
    args.distributed = dist.get_world_size() > 1

    vision_tower_name = args.vision_tower_name
    cache_dir = args.cache_dir
    side_size = args.side_size
    batch_size = args.batch_size

    dataset = FeaturesDataset(args.features_json_path, mode='train', image_processor=load_from_pil, side_size=side_size)
    eval_dataset = FeaturesDataset(args.val_features_json_path, mode='val', image_processor=load_from_pil, side_size=side_size)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    # The requested batch size is divided by the GPU count for the loader.
    loader = DataLoader(dataset, batch_size=batch_size // args.n_gpu, sampler=sampler, num_workers=16)

    for stage in range(len(args.epoch)):
        iva_factor = args.input_vector_amounts[stage]
        epoches = args.epoch[stage]
        encoder_layers, n_crops = args.encoder_layers, args.n_crops
        save_path = make_save_path(args.save_path, stage, iva_factor, epoches, encoder_layers, batch_size, n_crops)

        # A stage whose last-epoch checkpoint is present counts as finished.
        final_checkpoint = os.path.join(save_path, f'vqvae-{str(epoches).zfill(3)}.pt')
        if os.path.exists(final_checkpoint):
            continue

        model = VQVAE(iva_factor, encoder_layers, vision_tower_name, cache_dir).to(device)
        if stage > 0:
            prev_iva_factor = args.input_vector_amounts[stage - 1]
            prev_epoches = args.epoch[stage - 1]
            prev_save_path = make_save_path(args.save_path, stage - 1, prev_iva_factor, prev_epoches,
                                            encoder_layers, batch_size, n_crops)

            w_path = os.path.join(prev_save_path, f'vqvae-{str(prev_epoches).zfill(3)}.pt')
            print('--------------->: ', w_path)
            # Block until the previous stage's weights are on disk.
            while not os.path.exists(w_path):
                sleep(10)
                print(f'Not exists: {w_path}')
            previous_weights = torch.load(w_path, weights_only=True)
            model.load_state_dict(previous_weights, strict=False)
        calculat_num_parameters(model)

        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        if args.sched == "cycle":
            scheduler = CycleScheduler(
                optimizer,
                args.lr,
                n_iter=len(loader) * args.epoch[stage],
                momentum=None,
                warmup_proportion=0.05,
            )
        else:
            scheduler = None

        # Six fresh histories per split: mse, cos, lr, epoch, i, iva.
        train_tuple = tuple([] for _ in range(6))
        val_tuple = tuple([] for _ in range(6))
        for epoch_index in range(epoches):
            train(epoch_index, stage, loader, model, optimizer, scheduler, device, eval_dataset, args,
                  train_tuple, val_tuple, args.save_delta, batch_size)

            # Checkpoint after every epoch; only the primary process writes.
            os.makedirs(save_path, mode=0o777, exist_ok=True)
            if dist.is_primary():
                checkpoint_name = f'vqvae-{str(epoch_index + 1).zfill(3)}.pt'
                torch.save(model.state_dict(), os.path.join(save_path, checkpoint_name))


if __name__ == "__main__":
    # Command-line entry point: parse the options and launch `main` through
    # the project's distributed launcher.
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_gpu", type=int, default=1)  # 4, 8

    # Pick a per-user port in [49152, 65535] for the rendezvous URL.
    port = (
        2 ** 15
        + 2 ** 14
        + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    )
    parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}")

    # training parameters
    parser.add_argument("--vision_tower_name", type=str, default="OpenGVLab/InternVL3-2B")
    # One value per stage, e.g. `--epoch 60 20 20`. `type=list` would split the
    # raw string into single characters, so nargs='+' with a scalar type is used.
    parser.add_argument("--epoch", type=int, nargs='+', default=[60, 20, 20])  # [40, 40, 40]
    parser.add_argument("--encoder_layers", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)  # 32
    # One value per stage, e.g. `--input_vector_amounts 0.6 0.4 0.2`.
    parser.add_argument("--input_vector_amounts", type=float, nargs='+', default=[0.6, 0.6, 0.6])  # [0.6, 0.4, 0.2]
    parser.add_argument("--side_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-6)  # 1e-5 #1e-4 for all patches, 3e-4 default

    parser.add_argument("--n_crops", type=int, default=5)
    parser.add_argument("--sched", type=str, default='cycle')

    # data parameters
    parser.add_argument("--device", type=str, default='cuda')

    # cache dir for image_processor
    parser.add_argument("--cache_dir", type=str, default=None)

    # paths for dataset
    parser.add_argument("--features_json_path", type=str, default='../../../data/OpenGVLab-InternVL3-2B_mlp/map.json')
    parser.add_argument("--val_features_json_path", type=str, default='../../../data/OpenGVLab-InternVL3-2B_mlp/map_val.json')

    # paths for logging
    parser.add_argument("--save_path", type=str, default='./checkpoints')
    parser.add_argument("--sample_path", type=str, default='./samples')

    parser.add_argument("--save_delta", type=int, default=500)

    args = parser.parse_args()

    # main() indexes input_vector_amounts by stage for every entry of --epoch.
    if len(args.input_vector_amounts) < len(args.epoch):
        parser.error("--input_vector_amounts needs at least one value per --epoch stage")

    print(args)

    dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,))
