import sys
import os
import torch
sys.path.append(os.getcwd())
# from perceiver.perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality
from private_test_scripts.perceivers.crossattnperceiver import MultiModalityPerceiver, InputModality, MultiModalityPerceiverFaster
import argparse
from utils.tools import set_seed
from unimodals.common_models import Reshape

# Command-line options, declared in the order they are registered on the parser.
# The gentle_push-specific options (--no_vision ... --kloss_dataset) live on the
# same parser; presumably they are consumed via `PushTask.get_dataset_args(args)`
# further down — confirm against PushTask.
# NOTE: there is no --batch-size option; every loader below uses batch_size=32.
_CLI_OPTIONS = (
    ('--seed', {'type': int, 'default': 123}),
    ('--epochs', {'type': int, 'default': 100}),
    ('--lr', {'type': float, 'default': 0.001}),
    ('--enrico-path', {'type': str, 'default': 'datasets/data/enrico'}),
    ('--avmnist-path', {'type': str, 'default': 'datasets/data/avmnist'}),
    ('--gentle-push-path', {'type': str, 'default': 'datasets/gentle_push/cache'}),
    ('--model-path', {'type': str, 'default': 'private_test_scripts/model/medium.pth'}),
    ("--no_vision", {'action': "store_true"}),
    ("--no_proprioception", {'action': "store_true"}),
    ("--no_haptics", {'action': "store_true"}),
    ("--image_blackout_ratio", {'type': float, 'default': 0.0}),
    ("--sequential_image_rate", {'type': int, 'default': 1}),
    ("--kloss_dataset", {'action': "store_true"}),
    ('--is-train', {'type': int, 'default': 1}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
print(args)

# Limit torch to one intra-op CPU thread and seed all RNGs before any dataset
# or model construction below consumes randomness.
torch.set_num_threads(1)
set_seed(args.seed)

torch.multiprocessing.set_sharing_strategy('file_system')

# Forwarded as `debug=` to the avmnist and gentle_push loaders.
isdebug = False

# enrico
from datasets.enrico.get_data import get_dataloader
# Build the enrico train/valid/test loaders. `weights` is returned by the loader
# but is not used in the visible part of this script.
dls, weights = get_dataloader(args.enrico_path, batch_size=32, num_workers=8)
trains1, valid1, test1 = dls
# Keep only the first loader stored under the 'image' key of the test split
# (presumably the un-noised variant — confirm against datasets.enrico.get_data).
test1 = test1['image'][0]

# Sanity check: print each field of the first training sample (its shape when
# it has one). Fetch the sample once — every `dataset[0]` re-runs __getitem__.
_sample = trains1.dataset[0]
for _field in _sample:
    print(_field.shape if hasattr(_field, 'shape') else _field)
print("loaded enrico dataset successfully!!!")

# avmnist
from datasets.avmnist.get_data import get_dataloader
# Build the avmnist train/valid/test loaders.
trains2, valid2, test2 = get_dataloader(
    args.avmnist_path,
    flatten_audio=True,
    unsqueeze_channel=0,
    batch_size=32,
    cut_into=True,
    num_workers=8,
    debug=isdebug,
    hmmt=True,
)
print("loaded avmnist dataset successfully!!!")

# Sanity check: print each field of the first training sample (its shape when
# it has one). Previously the non-shape branch printed the whole sample instead
# of the current field; fetch the sample once instead of once per access.
_sample = trains2.dataset[0]
for _field in _sample:
    print(_field.shape if hasattr(_field, 'shape') else _field)
from datasets.gentle_push.data_loader import PushTask
import argparse
import fannypack

Task = PushTask
# Dataset options are taken from the script-level `args` parsed at the top
# (the gentle_push flags were registered on that same parser).
dataset_args = Task.get_dataset_args(args)

fannypack.data.set_cache_path(args.gentle_push_path)

# The leading positional 16 is presumably the subsequence length (matching the
# 16 in the modality shape notes below) — TODO confirm against PushTask.get_dataloader.
trains3, valid3, test3 = Task.get_dataloader(
    16,
    batch_size=32,
    drop_last=True,
    test_multimodal_only=True,
    num_workers=8,
    debug=isdebug,
    test_noises=[0],
)
# Keep only the first loader stored under the 'multimodal' key of the test split.
test3 = test3['multimodal'][0]
print("loaded gentle_push dataset successfully!!!")

# Sanity check: print each field of the first training sample (its shape when
# it has one). Fetch the sample once — every `dataset[0]` re-runs __getitem__.
_sample = trains3.dataset[0]
for _field in _sample:
    print(_field.shape if hasattr(_field, 'shape') else _field)

# Use the current CUDA device index when CUDA is available, otherwise the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

def _build_modality(name, input_channels, input_axis):
    """Create an InputModality with this script's shared Fourier-feature settings.

    Prints ``name`` followed by the modality's ``input_dim`` as a sanity check.
    The name is the key passed per task in the modality lists given to ``train``.

    Args:
        name: modality identifier.
        input_channels: number of channels for each token of the input.
        input_axis: number of axes (2 for images).

    Returns:
        The constructed InputModality.
    """
    modality = InputModality(
        name=name,
        input_channels=input_channels,
        input_axis=input_axis,
        num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
        max_freq=1.,  # maximum frequency, hyperparameter depending on how fine the data is
    )
    # Report this modality's own input_dim (the 'sensor' line used to print
    # pose_modality's value by mistake).
    print(name, modality.input_dim)
    return modality


# gentle_push modalities (shape notes from the original author)
# [B, 16, 3, 1]
pose_modality = _build_modality('pose', input_channels=3, input_axis=1)
# [B, 16, 7, 1]
sensor_modality = _build_modality('sensor', input_channels=7, input_axis=1)
# [B, 16, 32, 32, 1]
trajectory_modality = _build_modality('trajectory', input_channels=1, input_axis=3)
# [B, 16, 7, 1]
control_modality = _build_modality('control', input_channels=7, input_axis=1)

# enrico modalities
image_1_modality = _build_modality('image_1', input_channels=3, input_axis=2)
image_2_modality = _build_modality('image_2', input_channels=3, input_axis=2)

# avmnist modalities (paired with trains2 in the train() call)
image_modality = _build_modality('image', input_channels=16, input_axis=2)
audio_modality = _build_modality('audio', input_channels=256, input_axis=2)

# One shared Perceiver over every modality of all three tasks, in the same
# order as before. Per-task output heads are assigned to `model.to_logitslist`
# afterwards.
_all_modalities = (
    image_1_modality, image_2_modality,  # enrico
    image_modality, audio_modality,  # avmnist
    pose_modality, sensor_modality, trajectory_modality, control_modality,  # gentle_push
)

model = MultiModalityPerceiverFaster(
    modalities=_all_modalities,
    # Depth of the net; combined with num_latent_blocks_per_layer to produce the full Perceiver.
    depth=1,
    # This parameter is 1 in the original Lucidrain implementation.
    num_latent_blocks_per_layer=1,
    cross_depth=1,
    cross_heads=1,  # number of heads for cross attention; the paper said 1
    cross_dim_head=64,
    # Number of latents (a.k.a. induced set points / centroids, depending on the paper).
    num_latents=12,
    latent_dim=64,  # latent dimension
    latent_heads=8,  # number of heads for latent self-attention
    latent_dim_head=64,
    # Whether to weight-tie layers (optional, as indicated in the diagram).
    weight_tie_layers=True,
    attn_dropout=0.,
    ff_dropout=0.,
    num_classes=2,  # output number of classes
).to(device)

def _make_head(in_features, out_features, reshape_to=None):
    """Build a LayerNorm -> Linear output head (optionally followed by Reshape) on `device`.

    Layers are constructed in the order LayerNorm, Linear, Reshape, so
    parameter-initialisation RNG consumption matches an inline Sequential.
    """
    layers = [torch.nn.LayerNorm(in_features), torch.nn.Linear(in_features, out_features)]
    if reshape_to is not None:
        layers.append(Reshape(reshape_to))
    return torch.nn.Sequential(*layers).to(device)


# Output heads, presumably indexed by task position in train() (enrico,
# avmnist, gentle_push) — confirm in train_structure_multitask.
# NOTE(review): the 128/768 input widths must match the features the model
# feeds each head; 768 == num_latents * latent_dim (12 * 64) — TODO confirm.
model.to_logitslist = torch.nn.ModuleList([
    _make_head(128, 20),
    _make_head(128, 10),
    _make_head(768, 16 * 2, reshape_to=[-1, 16, 2]),
])

from private_test_scripts.perceivers.train_structure_multitask import train

# Multi-task run. Every per-task list below is ordered (enrico, avmnist, gentle_push).
train(
    model,
    args.epochs,
    [trains1, trains2, trains3],
    [valid1, valid2, valid3],
    [test1, test2, test3],
    # Modality names each task's inputs are routed to.
    [['image_1', 'image_2'], ['image', 'audio'], ['pose', 'sensor', 'trajectory', 'control']],
    args.model_path,
    lr=args.lr,
    device=device,
    train_weights=[0.8, 1.0, 1.1],
    eval_weights=[1.0, 1.0, 1.0],
    is_affect=[False, False, False],
    unsqueezing=[False, False, False],
    transpose=[False, False, False],
    # Honour the --is-train flag (default 1 -> True) instead of hard-coding True,
    # which made the option a no-op.
    is_train=bool(args.is_train),
    calc_flops=False,
)
