from __future__ import division
from __future__ import print_function
from __future__ import with_statement

import argparse
from importlib import import_module as impm
import logging
import numpy as np
import os
import random
import imageio
import pickle

import torch
import torchaudio

import _init_paths
from configs import cfg
from configs import update_config

from scipy.signal import oaconvolve
from scipy.interpolate import interp1d

from libs.utils import misc

from moviepy.editor import VideoFileClip, AudioFileClip
import soundfile as sf

from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm


def concat_images(image_names):
    """Paste the first two images side by side on a white RGB canvas.

    Args:
        image_names: sequence of image file paths; only the first two are
            used (left, then right).

    Returns:
        A new ``PIL.Image`` in RGB mode of size ``(w1 + w2, max(h1, h2))``
        with the first image at the origin and the second directly to its
        right. Any area not covered by an image stays white.
    """
    # Image.open keeps the file handle open until the image is closed; use
    # context managers so both handles are released once the pixels have
    # been copied onto the new canvas (this runs once per video frame).
    with Image.open(image_names[0]) as left, Image.open(image_names[1]) as right:
        w1, h1 = left.size
        w2, h2 = right.size
        target = Image.new('RGB', (w1 + w2, max(h1, h2)), (255, 255, 255))
        target.paste(left, (0, 0))
        target.paste(right, (w1, 0))
    return target

def normalize(sig):
    """Return a copy of *sig* scaled so that its peak absolute value is 1.

    Args:
        sig: array-like signal of any shape; the peak is taken over all
            elements (e.g. both channels of a stereo buffer together).

    Returns:
        A new numpy array; the input is never modified. An empty or
        all-zero signal is returned unscaled (as a copy) instead of raising
        on ``.max()`` or producing NaNs from a 0/0 division.
    """
    ir = np.copy(sig)
    if ir.size == 0:
        return ir
    max_value = np.abs(ir).max()
    if max_value == 0:
        return ir
    return ir / max_value

# Command-line settings

def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Options:
        --cfg     required experiment YAML path, stored as ``yaml_file``.
        --apt     scene name, default ``apartment_1``.
        --music   source audio name, default ``music2``.
        --source  emitter placement, default ``begin``.
        opts      every remaining token, forwarded as config overrides.
    """
    parser = argparse.ArgumentParser(description='Neural Acoustic')
    parser.add_argument('--cfg', dest='yaml_file', type=str, required=True,
                        help='experiment configure file name, e.g. configs/base_config.yaml')
    # Plain string options that differ only by flag and default value.
    for flag, fallback in (('--apt', 'apartment_1'),
                           ('--music', 'music2'),
                           ('--source', 'begin')):
        parser.add_argument(flag, type=str, default=fallback)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    return parser.parse_args()


# Parse the command line and unpack the scene / audio / emitter choices.
args = parse_args()
apt, music_type, source_type = args.apt, args.music, args.source

# Hand the parsed arguments (the --cfg YAML path and trailing `opts`) to the
# project's update_config, which presumably fills in the global cfg.
update_config(cfg, args)
ngpus_per_node = torch.cuda.device_count()

# Seed torch, numpy and Python's `random` with one value (cfg.seed offset by
# misc.get_rank()) and turn off cuDNN autotuning in favour of deterministic
# kernels, aiming for reproducible runs.
seed = cfg.seed + misc.get_rank()
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Select the compute device; for CUDA, pin work to GPU 0 first.
if cfg.device == 'cuda':
    torch.cuda.set_device(0)
device = torch.device(cfg.device)

# Output video path and the pre-recorded navigation observations for this scene.
video_name = f'demo/{apt}_{music_type}_{source_type}_test_pdb.mp4'
pickle_file = f'demo/sound-spaces/data/NERF_scene_observations/replica/{apt}_dump_video_24_1_3.pkl'
# NOTE: pickle.load can execute arbitrary code; only point this at trusted local data.
with open(pickle_file, 'rb') as file:
    arr = pickle.load(file)

# Frames are visited in sorted key order and written to the video at 24 fps.
keys = sorted(arr)
writer = imageio.get_writer(video_name, fps=24)

# Per-scene bin count (handed to build_render as n_bins further below).
max_len_by_apt = {'apartment_1': 8256, 'apartment_2': 7727, 'frl_apartment_2': 9564, 'frl_apartment_4': 9375, 'office_4': 7118, 'room_2': 7829}
max_len = max_len_by_apt[apt]

# load bounces: keep the mesh points inside a thin horizontal slice centred
# on `height` (column 3 of the first row of points.txt) and project them to 2-D.
data_root = f'data/matterport_data/{apt}/'
meta_points = np.loadtxt(os.path.join(data_root, 'points.txt'))
mesh_points = np.loadtxt(os.path.join(data_root, 'mesh.xyz'))
points = {int(k): [x, y] for k, x, y in meta_points[..., :3]}
height = meta_points[0, 3]
h_range = 1.5
space_height = 150
# Slice thickness. Named `slice_range` so it no longer shadows the builtin `range`.
slice_range = h_range / (space_height / (height - mesh_points[..., 2].min()))
mesh_z = mesh_points[..., 2]
in_slice = (mesh_z > height - slice_range / 2) & (mesh_z < height + slice_range / 2)
bounces = mesh_points[in_slice][..., :2]

# De-duplicate the bounce points on a rounded grid to get reflection patches.
# Office / room_2 scenes use a finer 0.1 grid and then keep every 5th patch;
# the grid is chosen up front instead of computing the coarse set and discarding it.
fine_grid = 'office' in apt or 'room_2' in apt
decimals = 1 if fine_grid else 0
patches = np.array(list(set([tuple([round(x, decimals), round(y, decimals)]) for x, y in bounces]))).reshape(-1, 2)
if fine_grid:
    patches = patches[::5]

# Bounding box used to map positions into [-1, 1]; for office scenes it is
# taken from the source/listener points only.
source_listener_pts = np.array([[x, y] for k, x, y in meta_points[..., :3]])
if 'office' in apt:
    all_pts = source_listener_pts
else:
    all_pts = np.concatenate([source_listener_pts, patches], axis=0)

max_pos = all_pts[:, :2].max(axis=0)
min_pos = all_pts[:, :2].min(axis=0)
norm_patches = ((patches - min_pos) / (max_pos - min_pos) - 0.5) * 2.0

# Build the renderer from the project module named in cfg, passing the
# per-scene bin count and the patch coordinates normalized to [-1, 1].
model = getattr(impm(cfg.render.file), 'build_render')(cfg.model.file, n_bins=max_len, patches=norm_patches)
model.eval()
model = torch.nn.DataParallel(model).to(device)


model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

# Restore pretrained weights. Report success only when a state dict was
# actually loaded. Otherwise warn loudly, because rendering would continue
# with whatever weights build_render produced.
resume_path = f'data/{apt}_long.pth'
if os.path.exists(resume_path):
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(resume_path, map_location='cpu')
    # resume
    if 'state_dict' in checkpoint:
        model.module.load_state_dict(checkpoint['state_dict'], strict=True)
        logging.info(f'==> model loaded from {resume_path} \n')
        print(f'==> model loaded from {resume_path} \n')
    else:
        msg = f'==> WARNING: no state_dict in {resume_path}; model NOT loaded \n'
        logging.warning(msg)
        print(msg)
else:
    msg = f'==> WARNING: checkpoint {resume_path} not found; model NOT loaded \n'
    logging.warning(msg)
    print(msg)

os.makedirs('results', exist_ok=True)

# Attach all-ones arrays of hard-coded shape (7, 178, 22) to the unwrapped
# model. NOTE(review): what these attributes mean, and why this shape, is
# defined in the model code — confirm there.
model_without_ddp.seq_all_band_l = np.ones((7, 178, 22))
model_without_ddp.seq_all_band_r = np.ones((7, 178, 22))

criterion = getattr(impm(cfg.train.criterion_file), 'Criterion')(cfg)

# build trainer: no optimizer, scheduler or logger is passed; the object
# is only used below to reach the wrapped renderer.
Trainer = getattr(impm(cfg.train.file), 'Trainer')(
    cfg,
    model,
    criterion=criterion,
    optimizer=None,
    lr_scheduler=None,
    logger=None,
    log_dir=cfg.log_dir,
    performance_indicator=cfg.pi,
    last_iter=-1,
    rank=0,
    device=device,
)

# Navigation map: a single axes onto which the listener trajectory is drawn.
figure = plt.figure()
fig = figure.add_subplot(1, 1, 1)

# Load the dry source signal, average over channels (dim 0) and keep a
# fixed excerpt given in seconds.
source_sound, sr = torchaudio.load(f'data/{music_type}.mp3')
print(sr, source_sound.shape)
start_sec, stop_sec = (45, 105) if music_type == 'lecun' else (0, 51)
source_sound = source_sound.mean(0).numpy()[start_sec * sr:stop_sec * sr]

# Grey background: the room grid coordinates, with the y axis flipped.
coor_path = os.path.join(f'/Learning_Neural_Acoustic_Fields/metadata/room_grid_coors/{apt}.pkl')
with open(coor_path, "rb") as coor_file_obj:
    plot_points = np.array(list(pickle.load(coor_file_obj).values()))
plot_points[..., 1] *= -1
fig.scatter(plot_points[:, 0], plot_points[:, 1], c='0.8')

# Emitter position: (x, -z) taken from the last field of the first, last or
# middle observation, depending on --source.
if source_type == 'begin':
    src_idx = 0
elif source_type == 'end':
    src_idx = len(arr) - 1
else:
    src_idx = int(len(arr) / 2)
src_pos = arr[src_idx][-1]
source = np.array([src_pos[0], -src_pos[2]])

# The small red star uses the same style as the per-frame listener markers.
# Presumably it is drawn here only to register the 'Listener' legend entry.
fig.scatter(source[0], source[1], marker='*', c='r', s=10, label='Listener')
fig.scatter(source[0], source[1], marker='>', c='g', s=200, label='Emitter')
plt.legend()

count = 0
audios = []
# Number of audio samples that accompany one 24-fps video frame.
single_len = int(sr / 24)
final_audio = []
# Global extremes of the convolved signals, printed after the loop.
max_res = -1000
min_res = 1000

# Loop-invariant inputs: the emitter position (raw and normalized to [-1, 1])
# and the temporary image paths reused for every frame.
source_points = torch.from_numpy(source).to(device).reshape(1, 1, -1)
norm_source = ((source - min_pos) / (max_pos - min_pos) - 0.5) * 2.0
norm_source_points = torch.from_numpy(norm_source).to(device).reshape(1, 1, -1)
frame_prefix = f'{apt}_{music_type}_source_{source_type}'
rgb_tmp = f'{frame_prefix}_tmp2.jpg'
nav_tmp = f'{frame_prefix}_nav_temp2.jpg'
combined_tmp = f'{frame_prefix}_tar2.jpg'


def _direction_blend(raw_angle):
    """Split a heading (degrees) into two neighbouring direction indices.

    Returns ``(lo, hi, frac)``:
        lo   - quadrant index 0..3 such that the wrapped heading lies in
               ``[90 * lo, 90 * (lo + 1)]``;
        hi   - the next direction index, ``(lo + 1) % 4``;
        frac - in ``[0, 1]``, how far the heading has moved from ``lo * 90``
               towards ``hi * 90``.

    This replaces an ``or``-chained range test that was always true, which
    mapped every heading to directions 0/1.
    """
    wrapped = raw_angle % 360
    # Float modulo can return exactly 360.0 for tiny negative inputs; clamping
    # to 3 makes that case land at frac == 1, i.e. on direction 0.
    lo = min(int(wrapped // 90), 3)
    hi = (lo + 1) % 4
    return lo, hi, (wrapped - lo * 90) / 90.


for k, key in tqdm(enumerate(keys), total=len(keys)):
    # The last two fields of each observation are the heading angle and the
    # listener position.
    angle, l_pos = arr[key][-2:]
    angle, angle1, scale = _direction_blend(angle)

    # Audio window for this frame, wrapping around the end of the source clip.
    audio_start = (k * single_len) % len(source_sound)
    audio_end = ((k + 1) * single_len) % len(source_sound)

    with torch.no_grad():
        listener_xy = np.array([l_pos[0], -l_pos[2]])
        cur_points = torch.from_numpy(listener_xy.reshape(-1, 2)).to(device).reshape(1, 1, -1)
        norm_cur_point = ((listener_xy - min_pos) / (max_pos - min_pos) - 0.5) * 2.0
        norm_cur_points = torch.from_numpy(norm_cur_point).to(device).reshape(1, 1, -1)
        dirs = torch.from_numpy(np.array(angle)).to(device).reshape(1, 1, -1)
        dirs1 = torch.from_numpy(np.array(angle1)).to(device).reshape(1, 1, -1)
        b_range = None

        # plot navpoint, then write [RGB view | nav map] as one video frame
        fig.scatter(listener_xy[0], listener_xy[1], marker='*', c='r', s=10)
        fig.axis('off')
        figure.savefig(nav_tmp)
        plt.imsave(rgb_tmp, arr[key][0]['rgb'][..., :3])

        target = concat_images([rgb_tmp, nav_tmp])
        target.save(combined_tmp)
        # Re-read the saved JPEG so the video frame matches the file on disk.
        target = plt.imread(combined_tmp)
        writer.append_data(np.array(target))

        pred_ir = Trainer.render.module.render(source_points, cur_points, norm_source_points, norm_cur_points, dirs, b_range)
        pred_ir1 = Trainer.render.module.render(source_points, cur_points, norm_source_points, norm_cur_points, dirs1, b_range)
        # Linear blend between the neighbouring directions: scale == 0 gives
        # direction `angle`, scale == 1 gives direction `angle1`. The previous
        # weighting was reversed, so the blend jumped at every 90° boundary.
        pred_ir = (1 - scale) * pred_ir + scale * pred_ir1
        pred_ir = pred_ir.squeeze(0)
        out_wav = []

        for cur_pred_ir in pred_ir:
            cur_pred_ir = cur_pred_ir.data.cpu().numpy()
            # Upsample the IR 2x: bins sit on even sample positions and are
            # linearly interpolated onto every integer position in between.
            x_bin = np.arange(0, len(cur_pred_ir) * 2, 2)
            x_dense = np.arange(0, len(x_bin) * 2 - 2)
            np_pred_ir = interp1d(x_bin, cur_pred_ir, kind='slinear')(x_dense)
            res = oaconvolve(source_sound, np_pred_ir).real
            max_res = max(max_res, res.max())
            min_res = min(min_res, res.min())
            # Align with the source, then cut this frame's (possibly wrapped) window.
            pred_wav = res[:len(source_sound)]
            if audio_end <= audio_start:
                pred_wav = np.hstack([pred_wav[audio_start:], pred_wav[:audio_end]])
            else:
                pred_wav = pred_wav[audio_start:audio_end]
            out_wav.append(pred_wav)

        # Assumes the renderer yields two IR channels -> (samples, 2) stereo chunk.
        cur_output = np.array(out_wav).reshape(2, -1).transpose(1, 0)
        final_audio.append(cur_output)
        count += 1
print(max_res, min_res)
writer.close()

# Concatenate the per-frame stereo chunks, peak-normalize, and mux the
# result with the rendered video.
final_audio = normalize(np.concatenate(final_audio))
audio_file = f'results_wav/final_audio_{apt}_{music_type}_source_{source_type}.wav'
final_video = f'finaldemo/apart1_final_{apt}_{music_type}_source_{source_type}.mp4'
# The script does not otherwise create these output directories, so writing
# into them would fail on a fresh checkout.
for out_path in (audio_file, final_video):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
sf.write(audio_file, final_audio, sr)

video_clip = VideoFileClip(video_name)
audio_clip = AudioFileClip(audio_file)
final_clip = video_clip.set_audio(audio_clip)
try:
    final_clip.write_videofile(final_video, fps=24)
finally:
    # Release the ffmpeg readers; final_clip is a copy sharing these readers.
    audio_clip.close()
    video_clip.close()