import argparse
import os
from tqdm import tqdm
import platform
import numpy as np
import statistics

import torch
from torch.utils.data import DataLoader, TensorDataset

import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils
from opencood.data_utils.datasets import build_dataset
from opencood.models.comm_modules.residual_vq import ResidualVQ

def iterate_dataset(data_loader, device):
    """Yield batches from ``data_loader`` forever, moved to ``device``.

    Each time the loader is exhausted it is iterated again from the start,
    so the generator never terminates on its own; the caller decides when
    to stop pulling batches.

    Parameters
    ----------
    data_loader : iterable
        Re-iterable source of batches (``iter()`` is called on it once per
        pass).
    device : object
        Forwarded unchanged to ``train_utils.to_device``.

    Yields
    ------
    object
        ``train_utils.to_device(batch, device)`` for every batch.

    Raises
    ------
    RuntimeError
        If a full pass over ``data_loader`` produces no batch at all.
    """
    while True:
        saw_batch = False
        for batch in data_loader:
            saw_batch = True
            yield train_utils.to_device(batch, device)
        if not saw_batch:
            # An empty loader can never produce a batch. Fail with a clear
            # message instead of the opaque "generator raised StopIteration"
            # error (or an endless busy loop).
            raise RuntimeError(
                "iterate_dataset: data_loader yielded no batches "
                "(is the dataset empty?)"
            )

def prune_parser():
    """Parse the command-line options of the codebook pruning script.

    Returns
    -------
    argparse.Namespace
        Holds ``model_dir`` (string, defaults to ``''``) and ``target_size``
        (comma-separated string, defaults to ``'16,16,16'``).
    """
    arg_parser = argparse.ArgumentParser(description="synthetic data generation")

    # (flag, default, help) for each optional string argument.
    option_table = (
        ('--model_dir', '', 'Codebook pruning path'),
        ('--target_size', '16,16,16', 'Target size of pruned codebooks'),
    )
    for flag, default_value, help_text in option_table:
        arg_parser.add_argument(flag,
                                type=str,
                                required=False,
                                default=default_value,
                                help=help_text)

    return arg_parser.parse_args()

def _extract_img_features(model, data_loader, device):
    """Run ``model`` over every batch and stack the returned image features.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch['ego'], return_img_features=True)``.
    data_loader : iterable
        Yields batches that contain an ``'ego'`` entry.
    device : torch.device
        Device the batches are moved to before the forward pass.

    Returns
    -------
    np.ndarray
        The per-batch features concatenated along axis 0.

    Raises
    ------
    RuntimeError
        If ``data_loader`` yields no batches.
    """
    feature_chunks = []
    # Only the forward activations are needed. Disabling autograd also keeps
    # ``.numpy()`` from failing on a tensor that requires grad.
    with torch.no_grad():
        for batch_data in tqdm(data_loader):
            batch_data = train_utils.to_device(batch_data, device)
            img_feature = model(batch_data['ego'], return_img_features=True)
            feature_chunks.append(img_feature.detach().cpu().numpy())

    if not feature_chunks:
        raise RuntimeError("No image features extracted: the dataset is empty")

    # Concatenate once; growing the array inside the loop re-copies all
    # previously collected features on every batch.
    return np.concatenate(feature_chunks, axis=0)


def _run_vq_epochs(residual_vq, data_loader, device, num_epochs=5):
    """Feed the features through ``residual_vq`` and record its losses.

    Parameters
    ----------
    residual_vq : ResidualVQ
        Quantizer called on the first element of every batch.
    data_loader : iterable
        Yields batches whose first element is the feature tensor.
    device : torch.device
        Device the batches are moved to.
    num_epochs : int
        Number of full passes over ``data_loader``.

    Returns
    -------
    tuple of (list, list)
        Per-batch mean of ``output[2][0]`` and per-batch mean absolute
        difference between ``output[0]`` and the input features.
    """
    commit_losses = []
    rec_losses = []
    for _ in range(num_epochs):
        for batch_data in tqdm(data_loader):
            batch_data = train_utils.to_device(batch_data, device)
            features = batch_data[0]
            output = residual_vq(features)
            commit_losses.append(statistics.mean(output[2][0].tolist()))
            rec_losses.append((output[0] - features).abs().mean().tolist())
    return commit_losses, rec_losses


if __name__ == '__main__':

    opt = prune_parser()
    target_size = [int(item) for item in opt.target_size.split(',')]

    hypes = yaml_utils.load_yaml(None, opt)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    num_quantizers = hypes['model']['args']['residual_vq']['num_quantizers']
    if num_quantizers != len(target_size):
        raise ValueError(
            "The original and target sizes do not match "
            f"(num_quantizers={num_quantizers}, target_size={target_size})"
        )

    print('Dataset Building')
    # NOTE: the original author marked this dataset construction with
    # "This should be modified".
    opencood_dataset = build_dataset(hypes, visualize=True, train=True)
    print(f"{len(opencood_dataset)} samples found.")
    feature_loader = DataLoader(opencood_dataset,
                                batch_size=1,
                                num_workers=8,
                                collate_fn=opencood_dataset.collate_batch_test,
                                shuffle=True,
                                pin_memory=False,
                                drop_last=False)

    print('Creating Model')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = train_utils.create_model(hypes)
    model.to(device)

    print('Loading Model from checkpoint')
    saved_path = opt.model_dir
    _, model = train_utils.load_saved_model(saved_path, model, epoch=None)
    model.eval()

    print('Retrieve all image hidden features from dataset')
    # The features are cached next to the checkpoint so that later runs can
    # skip the forward pass over the whole dataset.
    img_feature_save_path = os.path.join(opt.model_dir, 'img_features.npy')
    if os.path.isfile(img_feature_save_path):
        img_features = np.load(img_feature_save_path)
    else:
        img_features = _extract_img_features(model, feature_loader, device)
        np.save(img_feature_save_path, img_features)

    print('Start codebook pruning')
    full_state_dict = model.state_dict()
    vq_prefix = 'residual_vq.'
    # Strip only the leading prefix so the keys match a standalone ResidualVQ.
    residual_vq_state_dict = {
        key[len(vq_prefix):]: value
        for key, value in full_state_dict.items()
        if key.startswith(vq_prefix)
    }
    vq_args = hypes['model']['args']['residual_vq']
    residual_vq = ResidualVQ(dim=vq_args['input_dim'],
                             accept_image_fmap=vq_args['accept_image_fmap'],
                             codebook_size_ls=vq_args['codebook_size_ls'],
                             num_quantizers=vq_args['num_quantizers'])
    residual_vq.load_state_dict(residual_vq_state_dict)
    residual_vq.to(device)
    # NOTE(review): residual_vq is left in its default (training) mode, as in
    # the original script; presumably the passes below are meant to refresh
    # the codebook statistics used by codebook_pruning_function -- confirm.

    prune_dataset = TensorDataset(torch.Tensor(img_features))
    prune_loader = DataLoader(
        dataset=prune_dataset,
        batch_size=1024,
        shuffle=True
    )

    # Losses with the full-size codebooks.
    pruning_pre_loss_ls, pre_rec_loss_ls = _run_vq_epochs(
        residual_vq, prune_loader, device)

    for k, remain_size in enumerate(target_size):
        residual_vq.layers[k]._codebook.codebook_pruning_function(
            remain_size=remain_size)

    # Losses with the pruned codebooks.
    pruning_rec_loss_ls, next_rec_loss_ls = _run_vq_epochs(
        residual_vq, prune_loader, device)

    # Copy the pruned codebook buffers back into the full model state dict.
    new_residual_vq_dict = residual_vq.state_dict()
    suffixes = ['initted', 'cluster_size', 'embed_sum', 'embed']
    for i in range(len(target_size)):
        for suffix in suffixes:
            full_state_dict['residual_vq.layers.%d._codebook.%s' % (i, suffix)] = \
                new_residual_vq_dict['layers.%d._codebook.%s' % (i, suffix)]

    torch.save(full_state_dict, os.path.join(opt.model_dir, 'latest.pth'))
    hypes['model']['args']['residual_vq']['codebook_size_ls'] = target_size
    yaml_utils.save_yaml(hypes, os.path.join(opt.model_dir, 'config.yaml'))

    # Print the last (up to) six post-pruning reconstruction losses; slicing
    # avoids an IndexError when fewer than six batches were processed.
    for rec_loss in next_rec_loss_ls[-6:]:
        print(rec_loss)
    print('Finished')
