import sys
import os
sys.path.insert(1,os.getcwd())
#from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality
from private_test_scripts.perceivers.crossattnperceiver import MultiModalityPerceiver, InputModality
import torch
import argparse
from utils.tools import set_seed, count_param
from unimodals.common_models import Reshape

# ---------------------------------------------------------------------------
# Command-line options, seeding and dataset loading for the multitask run.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.0008)
# parser.add_argument('--batch-size', type=int, default=128)

# Dataset locations.  The imputed-data file for the datasets.mimic loader used
# to be a literal absolute path inside the loader call; it is exposed here like
# the other dataset paths.  The default is the previous literal, so running the
# script without the new flag behaves exactly as before.
parser.add_argument('--mimic-imputed-path', type=str,
                    default='/data/workspace/pengjie/MultiBench/datasets/data/im.pk')
parser.add_argument('--avmnist-path', type=str, default='datasets/data/avmnist')
parser.add_argument('--mosei-senti-path', type=str, default='datasets/data/mosei_senti_data-001.pkl')
parser.add_argument('--model-path', type=str, default='private_test_scripts/model/threetasks.pth')

# Integer switches: the script treats the value 1 as "on" and anything else as "off".
parser.add_argument('--is-train', type=int, default=1)
parser.add_argument('--debug', type=int, default=1)

args = parser.parse_args()
print(args)

# Seed before any loader or model is created.
set_seed(args.seed)

is_debug = args.debug == 1
torch.multiprocessing.set_sharing_strategy('file_system')

# NOTE: three of the loader modules export a function called `get_dataloader`.
# Each import below rebinds that name, so every call must stay directly under
# its own import.

# Task 1: datasets.mimic loader.
from datasets.mimic.get_data import get_dataloader
trains1, valid1, test1 = get_dataloader(7, imputed_path=args.mimic_imputed_path,
                                        batch_size=20, debug=is_debug, hmmt=True)
# Only the first entry under the 'timeseries' key of the test split is kept.
test1 = test1['timeseries'][0]

# Task 2: datasets.avmnist loader.
from datasets.avmnist.get_data import get_dataloader
trains2, valid2, test2 = get_dataloader(args.avmnist_path, unsqueeze_channel=False,
                                        cut_into=True, debug=is_debug)

# Task 3: datasets.affect loader, fed with the MOSEI sentiment pickle.
from datasets.affect.get_data import get_simple_processed_data
trains3, valid3, test3 = get_simple_processed_data(args.mosei_senti_path, fracs=1,
                                                   repeats=1, debug=is_debug)

# Task 4: humorloader.
# NOTE(review): the positional arguments 1, 32, 5 are not explained in this
# file - confirm their meaning against humorloader.get_dataloader.
from private_test_scripts.perceivers.humorloader import get_dataloader
trains4, valid4, test4 = get_dataloader(1, 32, 5, debug=is_debug)

# Debug aid: for the first training sample of tasks 2-4, print the shape of
# every field, or the field itself when it has no `shape` attribute.
# The same fallback is applied to all three loaders, so a shapeless field in
# the trains2 sample is printed instead of raising AttributeError.
# Each sample is fetched once per loader rather than once per field.
for loader in (trains2, trains3, trains4):
    first_sample = loader.dataset[0]
    for field in first_sample:
        print(field.shape if hasattr(field, 'shape') else field)
# exit()
# Limit torch to a single intra-op thread for this process.
torch.set_num_threads(1)

# Run on the currently selected CUDA device (an integer index) when CUDA is
# available; otherwise fall back to the string 'cpu'.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    
    
def _make_modality(name, channels, axis, freq_bands):
    """Build an InputModality; `max_freq` is 1 for every modality in this script."""
    return InputModality(
        name=name,
        input_channels=channels,
        input_axis=axis,
        num_freq_bands=freq_bands,
        max_freq=1,
    )


# Modalities with 6 frequency bands.
static_modality = _make_modality('static', channels=1, axis=1, freq_bands=6)
timeseries_modality = _make_modality('timeseries', channels=1, axis=2, freq_bands=6)
colorless_image_modality = _make_modality('colorlessimage', channels=16, axis=2, freq_bands=6)
audio_spec_modality = _make_modality('audiospec', channels=256, axis=2, freq_bands=6)

# Feature modalities: one input axis and 3 frequency bands each; they differ
# only in channel count.
feature1_modality = _make_modality('feature1', channels=35, axis=1, freq_bands=3)
feature2_modality = _make_modality('feature2', channels=74, axis=1, freq_bands=3)
feature3_modality = _make_modality('feature3', channels=300, axis=1, freq_bands=3)
feature4_modality = _make_modality('feature4', channels=371, axis=1, freq_bands=3)
feature5_modality = _make_modality('feature5', channels=81, axis=1, freq_bands=3)
for i in range(1):
    # Every modality the shared model is configured with, across all tasks.
    all_modalities = (
        colorless_image_modality, audio_spec_modality,
        feature1_modality, feature2_modality, feature3_modality,
        feature4_modality, feature5_modality,
        static_modality, timeseries_modality,
    )

    # Modality names handed to `train`, one list per task, in the same order
    # as the loaders passed below.
    task_modalities = [
        ['static', 'timeseries'],
        ['colorlessimage', 'audiospec'],
        ['feature1', 'feature2', 'feature3'],
        ['feature4', 'feature5', 'feature3'],
    ]
    num_tasks = 4

    model = MultiModalityPerceiver(
        modalities=all_modalities,
        # Depth of the net; combined with num_latent_blocks_per_layer it
        # produces the full Perceiver.
        depth=1,
        # Number of latents (a.k.a. induced set points / centroids; papers
        # use different names).
        num_latents=20,
        latent_dim=64,       # latent dimension
        cross_heads=1,       # heads for cross attention; the paper said 1
        latent_heads=6,      # heads for latent self attention (original note: 8)
        cross_dim_head=64,
        latent_dim_head=64,
        num_classes=1,       # output number of classes
        attn_dropout=0.,
        ff_dropout=0.,
        # embed=True is intentionally left disabled.
        weight_tie_layers=True,
        # This parameter is 1 in the original lucidrains implementation.
        num_latent_blocks_per_layer=1,
        cross_depth=1,
    )
    model = model.to(device)

    # One LayerNorm + Linear head per entry: (input width, number of outputs).
    # NOTE(review): 128 and 384 presumably match the fused feature width for
    # the 2- and 3-modality tasks - confirm against crossattnperceiver.
    head_shapes = [(128, 2), (128, 10), (384, 2), (384, 2)]
    task_heads = [
        torch.nn.Sequential(torch.nn.LayerNorm(width), torch.nn.Linear(width, n_out))
        for width, n_out in head_shapes
    ]
    model.to_logitslist = torch.nn.ModuleList(task_heads).to(device)

    from private_test_scripts.perceivers.train_structure_multitask import train

    # print(count_param(model))

    records = train(
        model,
        args.epochs,
        [trains1, trains2, trains3, trains4],
        [valid1, valid2, valid3, valid4],
        [test1, test2, test3, test4],
        task_modalities,
        args.model_path,
        lr=args.lr,
        device=device,
        criterions=[torch.nn.CrossEntropyLoss() for _ in range(num_tasks)],
        train_weights=[1.2, 0.9, 1.1, 1.0],
        is_affect=[False] * num_tasks,
        unsqueezing=[False] * num_tasks,
        transpose=[False] * num_tasks,
        evalweights=[1] * num_tasks,
        start_from=0,
        is_classification=[True] * num_tasks,
        weight_decay=0.001,
        is_train=args.is_train == 1,
        calc_flops=False,
    )
