import sys
import os
import torch
sys.path.append(os.getcwd())
# from perceiver.perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality
from private_test_scripts.perceivers.crossattnperceiver import MultiModalityPerceiver, InputModality, MultiModalityPerceiverFaster
import argparse
from utils.tools import set_seed, create_dir, count_param
from custom_model.common_models import Reshape
from custom_model.iclr24_model import MultiModalityMoETransformer, MultiModalityConfig, MultiModalitySequenceMoETransformer, GATES, SM3TaskHeads

# ---------------------------------------------------------------------------
# Command-line interface.
# Integer 0/1 options are used as booleans and are compared with `== 1`
# further down.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.0008)
parser.add_argument('--avmnist-path', type=str, default='datasets/data/avmnist')
parser.add_argument('--mosei-senti-path', type=str, default='datasets/data/mosei_senti_data-001.pkl')
# Imputed MIMIC-III pickle. This used to be hard-coded in the loader call below;
# the default keeps the previous location, so existing runs are unaffected.
parser.add_argument('--mimic-imputed-path', type=str, default='/data/workspace/pengjie/MultiBench/datasets/data/im.pk')
parser.add_argument('--model-path', type=str, default='private_test_scripts/model/threetasks.pth')
parser.add_argument('--img-path', type=str, default='log/img')
parser.add_argument('--unlimited-capacity-on-mlp', type=int, default=1)
parser.add_argument('--seperate-qkv', type=int, default=1)
parser.add_argument('--co-input', type=int, default=1)

# The next options are not read directly in this script; `args` is handed to train() as a whole.
parser.add_argument("--no_vision", action="store_true")
parser.add_argument("--no_proprioception", action="store_true")
parser.add_argument("--no_haptics", action="store_true")
parser.add_argument("--image_blackout_ratio", type=float, default=0.0)
parser.add_argument("--sequential_image_rate", type=int, default=1)
parser.add_argument("--kloss_dataset", action="store_true")
parser.add_argument('--is-train', type=int, default=1)
parser.add_argument('--gate-type', type=str, default='NoisyGate', choices=GATES.keys())

parser.add_argument('--modality-gating-merge', type=int, default=0)
# NOTE(review): this default has 3 entries, but four tasks are passed to train()
# as train_weights. Confirm how train() treats the missing 4th (MIMIC-III) weight.
parser.add_argument('--training-weight', type=float, default=[0.9, 1.1, 1.5], nargs='+')
parser.add_argument('--dynamic-reweight', type=int, default=0)
parser.add_argument('--cross-modality-attn', type=int, default=0)
parser.add_argument('--cross-depth', type=int, default=1)
parser.add_argument('--grad-clip', type=int, default=0)
parser.add_argument('--grad-clip-value', type=float, default=1.)
parser.add_argument("--drop-rate", type=float, default=0.0)
parser.add_argument('--moe-gate-weight', type=float, default=0.1)
parser.add_argument('--tune-gate-weight', type=int, default=0)
parser.add_argument('--mlp-top-k', type=int, default=2)
parser.add_argument('--attn-top-k', type=int, default=2)
parser.add_argument('--weight-decay', type=float, default=0.001)

parser.add_argument('--task-gating-merge', type=int, default=0)
parser.add_argument('--modality-joint', type=int, default=0)
parser.add_argument('--attn-modality-specific', type=int, default=1)
parser.add_argument('--mlp-modality-specific', type=int, default=0)

# Debug is ON by default (1). It is forwarded to every data loader and used as
# train()'s calc_flops flag; pass `--debug 0` for a full run.
parser.add_argument('--debug', type=int, default=1)

args = parser.parse_args()
print(args)
set_seed(args.seed)
create_dir(args.img_path)

tune_gate_weight = args.tune_gate_weight == 1
attn_modality_specific = args.attn_modality_specific == 1
mlp_modality_specific = args.mlp_modality_specific == 1

is_debug = args.debug == 1

# ---------------------------------------------------------------------------
# Data loading: one (train, valid, test) triple per task.
# Each dataset package exposes a function called `get_dataloader`. Aliasing the
# imports makes it explicit which loader builds which task's splits.
# ---------------------------------------------------------------------------
torch.multiprocessing.set_sharing_strategy('file_system')

from datasets.avmnist.get_data import get_dataloader as get_avmnist_dataloader
trains1, valid1, test1 = get_avmnist_dataloader(args.avmnist_path, unsqueeze_channel=False, cut_into=True, num_workers=4, debug=is_debug)
print("Load avmnist finished.")

from datasets.affect.get_data import get_simple_processed_data
trains2, valid2, test2 = get_simple_processed_data(args.mosei_senti_path, fracs=1, repeats=1, num_workers=8, debug=is_debug)
print("Load mosei finished.")

from private_test_scripts.perceivers.humorloader import get_dataloader as get_humor_dataloader
# NOTE(review): the positional arguments (1, 32, 5) are passed through unchanged;
# their meaning is defined in humorloader. Confirm before changing them.
trains3, valid3, test3 = get_humor_dataloader(1, 32, 5, num_workers=8, debug=is_debug)
print("Load humor finished.")

from datasets.mimic.get_data import get_dataloader as get_mimic_dataloader
trains4, valid4, test4 = get_mimic_dataloader(7, imputed_path=args.mimic_imputed_path, batch_size=20, debug=is_debug)
# The MIMIC test split is indexed by key here; only the first entry under
# 'timeseries' is kept and used as the test set.
test4 = test4['timeseries'][0]
print("Load mimiciii finished.")

torch.set_num_threads(1)

# Select the current CUDA device index when a GPU is available; otherwise fall back to the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

print(device)
# Input specifications for every distinct modality used by the four tasks.
# All of them use max_freq=1. Only the AV-MNIST inputs (colorlessimage and
# audiospec) have input_axis=2; all others have input_axis=1.

# MIMIC-III task.
static_modality = InputModality(name='static', input_channels=1, input_axis=1, num_freq_bands=6, max_freq=1)
timeseries_modality = InputModality(name='timeseries', input_channels=12, input_axis=1, num_freq_bands=3, max_freq=1)

# AV-MNIST task.
colorless_image_modality = InputModality(name='colorlessimage', input_channels=16, input_axis=2, num_freq_bands=6, max_freq=1)
audio_spec_modality = InputModality(name='audiospec', input_channels=256, input_axis=2, num_freq_bands=6, max_freq=1)

# Feature sequences: feature1-3 feed the MOSEI task and feature3-5 feed the
# humor task, so feature3 is shared between the two.
feature1_modality = InputModality(name='feature1', input_channels=35, input_axis=1, num_freq_bands=3, max_freq=1)
feature2_modality = InputModality(name='feature2', input_channels=74, input_axis=1, num_freq_bands=3, max_freq=1)
feature3_modality = InputModality(name='feature3', input_channels=300, input_axis=1, num_freq_bands=3, max_freq=1)
feature4_modality = InputModality(name='feature4', input_channels=371, input_axis=1, num_freq_bands=3, max_freq=1)
feature5_modality = InputModality(name='feature5', input_channels=81, input_axis=1, num_freq_bands=3, max_freq=1)

# MoE configuration shared by the model and by train().
# The integer CLI flags are converted to booleans with `== 1`.
mconfig = MultiModalityConfig(
    capacity_per_expert=86,
    seed=args.seed,
    img_path=args.img_path,
    unlimited_capacity_on_mlp=args.unlimited_capacity_on_mlp == 1,
    num_tasks=4,
    seperate_qkv=args.seperate_qkv == 1,
    co_input=args.co_input == 1,
    base_capacity=32,
    gate=GATES[args.gate_type],
    # The per-task modality groups, concatenated in task order (AV-MNIST,
    # MOSEI, humor, MIMIC-III). 'feature3' appears twice because both MOSEI
    # and humor use it, which matches modalities_num=[2, 3, 3, 2] below.
    modalities_name=[
        'colorlessimage', 'audiospec',
        'feature1', 'feature2', 'feature3',
        'feature4', 'feature5', 'feature3',
        'static', 'timeseries',
    ],
    attn_modality_specific=attn_modality_specific,
    mlp_modality_specific=mlp_modality_specific,
    modality_gating_merge=args.modality_gating_merge == 1,
    capacity_ratio=1.0,
    dynamic_reweight=args.dynamic_reweight == 1,
    cross_modality_attn=args.cross_modality_attn == 1,
    mlp_top_k=args.mlp_top_k,
    attn_top_k=args.attn_top_k,
    grad_clip_value=args.grad_clip_value,
    task_gating_merge=args.task_gating_merge == 1,
    modality_joint=args.modality_joint == 1,
)

# When modality gating is merged, hand MultiModalityConfig a mapping from each
# modality name to a merged group name. The semantics of the groups live in
# MultiModalityConfig.setting_modality_remap.
# 'feature3' is shared by two tasks but needs only one key here. Repeating the
# key in a dict literal is redundant: the later entry silently overwrites the
# earlier one.
if mconfig.modality_gating_merge:
    if mconfig.modality_joint:
        # Joint mode: every modality maps to the same group, 'image'.
        mconfig.setting_modality_remap(
            {
                'colorlessimage': "image",
                'audiospec': 'image',
                'feature1': 'image',
                'feature2': 'image',
                'feature3': 'image',
                'feature4': 'image',
                'feature5': 'image',
                'static': 'image',
                'timeseries': 'image'
            }
        )
    else:
        # Separate mode: modalities are grouped into image / audio / text,
        # and static and timeseries each keep a group of their own.
        mconfig.setting_modality_remap(
            {
                'colorlessimage': "image",
                'audiospec': 'audio',
                'feature1': 'image',
                'feature2': 'audio',
                'feature3': 'text',
                'feature4': 'image',
                'feature5': 'audio',
                'static': 'static',
                'timeseries': 'timeseries'
            }
        )

if mconfig.task_gating_merge:
    # Map every task id ('0'..'3') onto task '0'.
    mconfig.setting_task_remap({str(task_id): '0' for task_id in range(4)})



# Multi-task sequence MoE model over the nine distinct modalities, moved to `device`.
model = MultiModalitySequenceMoETransformer(
    modalities=(
        colorless_image_modality,
        audio_spec_modality,
        feature1_modality,
        feature2_modality,
        feature3_modality,
        feature4_modality,
        feature5_modality,
        static_modality,
        timeseries_modality,
    ),
    depth=1,  # depth of the net; combined with num_latent_blocks_per_layer to produce the full Perceiver
    cross_heads=1,  # number of cross-attention heads (the paper uses 1)
    num_latents=12,  # number of latents (also called induced set points or centroids in other papers)
    latent_dim=64,  # latent dimension
    latent_heads=8,  # number of latent self-attention heads
    cross_dim_head=64,
    latent_dim_head=64,
    # NOTE(review): model.to_logitslist is reassigned with per-task heads right
    # after construction; confirm whether num_classes still has any effect.
    num_classes=2,
    attn_dropout=args.drop_rate,
    ff_dropout=args.drop_rate,
    cross_depth=args.cross_depth,
    weight_tie_layers=True,  # whether to weight-tie layers (optional, as indicated in the diagram)
    # Appears to be the modality count per task: 2 + 3 + 3 + 2 matches the
    # per-task modality lists given to train(), with 'feature3' counted twice.
    modalities_num=[2, 3, 3, 2],
    num_latent_blocks_per_layer=1,  # this is 1 in the original lucidrains implementation
    args=mconfig,
).to(device)

# Replace the model's output heads with one SM3TaskHeads per task, in task
# order. Each head receives the task name, the modalities it reads, the
# constant 32 (meaning defined by SM3TaskHeads — confirm there) and the device.
model.to_logitslist = torch.nn.ModuleList([
    SM3TaskHeads(task_name, task_modalities, 32, device).to(device)
    for task_name, task_modalities in (
        ('av_mnist', ['colorlessimage', 'audiospec']),
        ('mosei', ['feature1', 'feature2', 'feature3']),
        ('humor', ['feature4', 'feature5', 'feature3']),
        ('mimic', ['static', 'timeseries']),
    )
])

print(count_param(model))
from private_test_scripts.mmoe_transformer.multitask_training import train

# Run multi-task training on the four tasks, always in the same order:
# AV-MNIST, MOSEI, humor, MIMIC-III. Each task is a classification task
# trained with its own CrossEntropyLoss instance.
train(
    model,
    args.epochs,
    [trains1, trains2, trains3, trains4],
    [valid1, valid2, valid3, valid4],
    [test1, test2, test3, test4],
    [
        ['colorlessimage', 'audiospec'],
        ['feature1', 'feature2', 'feature3'],
        ['feature4', 'feature5', 'feature3'],
        ['static', 'timeseries'],
    ],
    args.model_path,
    lr=args.lr,
    device=device,
    train_weights=args.training_weight,
    eval_weights=[1.0] * 4,
    is_classification=[True] * 4,
    criterions=[torch.nn.CrossEntropyLoss() for _ in range(4)],
    is_affect=[False] * 4,
    unsqueezing=[False] * 4,
    transpose=[False] * 4,
    weight_decay=args.weight_decay,
    is_train=args.is_train == 1,
    args=args,
    mconfig=mconfig,
    calc_flops=is_debug,
    tune_gate_weight=tune_gate_weight,
    moe_gate_weight=[args.moe_gate_weight] * 4,
    grad_clip=args.grad_clip == 1,
)