import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from models.GTDM_Model import Conv_GTDM_Controller
from PickleDataset import PickleDataset, transform_data
from video_generator import VideoGenerator
from cache_datasets import cache_data
import configargparse
import time
import matplotlib.pyplot as plt
'''

This file is purely for visualization: it shows what the test data looks like and the allocation among layers generated by the controller.


'''
def computeDist(tensor1, tensor2):
    """Return the Euclidean (L2) distance between two tensors.

    Both inputs are squeezed first, so e.g. shapes ``(1, 3)`` and ``(3,)``
    compare as 3-vectors. ``squeeze`` reduces single-element inputs to 0-d
    tensors, so they are promoted back to 1-D to avoid the
    ``len() of a 0-d tensor`` error the element-wise loop used to raise.

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; after squeezing it must have the same size
            along dim 0 as ``tensor1``.

    Returns:
        A tensor holding the distance (0-d for vector inputs). For inputs that
        are still 2-D after squeezing, the reduction is over dim 0 only,
        matching the previous row-by-row accumulation.
    """
    a = torch.atleast_1d(torch.squeeze(tensor1))
    b = torch.atleast_1d(torch.squeeze(tensor2))
    # Sum of squared differences over dim 0, then square root.
    return torch.sqrt(torch.sum((a - b) ** 2, dim=0))

def _str2bool(value):
    """Parse a boolean option from the command line or a config file.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``, so any
    non-empty string (including ``false`` from a YAML config) would be read as
    true. This accepts the usual spellings explicitly.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    # argparse converts a ValueError from a type callable into a usage error.
    raise ValueError(f'Expected a boolean value, got {value!r}')


def _split_files(split_root):
    """Return the hdf5 file paths for one dataset split directory.

    Order: ``mocap.hdf5`` first, then for each modality (mmwave, realsense,
    respeaker, zed) the files of nodes 1, 2 and 3.
    """
    modalities = ('mmwave', 'realsense', 'respeaker', 'zed')
    nodes = (1, 2, 3)
    files = [f'{split_root}/mocap.hdf5']
    files.extend(
        f'{split_root}/node_{node}/{mod}.hdf5' for mod in modalities for node in nodes
    )
    return files


def get_args_parser():
    """Parse CLI arguments (overriding ``./configs/configs.yaml``).

    Besides the declared options, the returned namespace gets ``trainset``,
    ``valset`` and ``testset``: lists of hdf5 file paths under
    ``<base_root>/train``, ``<base_root>/val`` and ``<base_root>/test``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = configargparse.ArgumentParser(description='GTDM Controller Testing, load config file and override params',
                                           default_config_files=['./configs/configs.yaml'], config_file_parser_class=configargparse.YAMLConfigFileParser)
    # Define the parameters with their default values and types
    parser.add("--base_root", type=str, help="Base directory for datasets")
    parser.add("--cache_dir", type=str, help="Directory to cache datasets")
    parser.add("--valid_mods", type=str, nargs="+", help="List of valid modalities")
    parser.add("--valid_nodes", type=int, nargs="+", help="List of valid nodes")
    parser.add("--learning_rate", type=float, default=1e-6, help="Learning rate for training")
    parser.add("--num_epochs", type=int, default=200, help="Number of epochs to train")
    parser.add("--adapter_hidden_dim", type=int, default=512, help="Dimension of adapter hidden layers")
    parser.add("--batch_size", type=int, default=32, help="Batch size for training")
    parser.add("--save_best_model", type=_str2bool, default=True, help="Save the best model")
    parser.add("--save_every_X_model", type=int, default=5, help="Save model every X epochs")
    parser.add('--total_layers', type=int, default=8, help="How many layers to reduce to")
    parser.add('--seedVal', type=int, default=100, help="Seed for training")
    parser.add('--folder', type=str, default='./logs', help='Folder containing the model')
    parser.add('--checkpoint', type=str, default='last.pt', help="ckpt name")
    parser.add('--drop_layers_img', type=int, nargs="+", help='List of layers to DROP')
    parser.add('--drop_layers_depth', type=int, nargs="+", help='List of layers to DROP')
    parser.add('--test_type', type=str, default='continuous', choices=['continuous', 'discrete', 'finite'])

    # Parse arguments from the configuration file and command-line
    args = parser.parse_args()
    if args.base_root is None:
        # Without this, the path concatenation below fails with an opaque TypeError.
        parser.error('--base_root must be set (on the command line or in the config file)')

    args.trainset = _split_files(args.base_root + '/train')
    args.valset = _split_files(args.base_root + '/val')
    args.testset = _split_files(args.base_root + '/test')
    return args


def main(args):
    """Visualize test samples and print the controller's predictions.

    Caches the datasets described by ``args`` (via ``cache_data``), then loads
    a trained ``Conv_GTDM_Controller`` checkpoint from
    ``<args.folder>/<args.checkpoint>``. For every test batch it prints
    ``results['early_fusion']['pred_mean']``. It also shows, side by side,
    the first sample's node-3 ZED left image (untransformed) and its node-3
    RealSense depth image (after ``transform_data``).

    Args:
        args: Namespace returned by ``get_args_parser``.
    """
    folder = str(args.folder)
    cache_data(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the architecture with the modalities/nodes/hidden size from the config.
    model = Conv_GTDM_Controller(
        args.adapter_hidden_dim,
        valid_mods=args.valid_mods,
        valid_nodes=args.valid_nodes,
        total_layers=args.total_layers,
    )
    # Load the weights onto the CPU first. This lets a checkpoint saved on a
    # GPU be opened on a CPU-only machine too; the model is moved to `device`
    # right after.
    ckpt_path = folder + '/' + str(args.checkpoint)
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=True)
    model.eval()  # eval mode, e.g. disables dropout
    model.to(device)

    # NOTE(review): the test cache path is hard-coded and ignores args.cache_dir — confirm intended.
    testset = PickleDataset('/mnt/hdd1/redacted/data/Noisy_Dataset_Cache/Blur/test', args.valid_mods, args.valid_nodes)
    test_dataloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False)

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc='Computing test loss', leave=False):
            data = batch['data']
            # Alias, not a copy. If transform_data mutates its input in place,
            # orig_data reflects that too — TODO confirm.
            orig_data = data
            # NOTE(review): data is not moved to `device` here; presumably
            # transform_data handles placement — confirm.
            data = transform_data(data)

            results, _, _ = model(data)  # Evaluate on test data
            print("Predicted Results: ", results['early_fusion']['pred_mean'])

            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            axes[0].imshow(orig_data['zed_camera_left', 'node_3'][0].permute(1, 2, 0).cpu().numpy())
            axes[0].axis('off')
            axes[1].imshow(data['realsense_camera_depth', 'node_3'][0].permute(1, 2, 0).cpu().numpy())
            axes[1].axis('off')
            plt.show()
            # Release the figure. Otherwise every batch leaks one, e.g. with a
            # non-interactive backend where show() returns immediately.
            plt.close(fig)
           
                   


    
    

if __name__ == '__main__':
    # Script entry point: parse config/CLI options, then run the visualization.
    main(get_args_parser())

