import sys
import os
import copy
import torch
sys.path.append(os.getcwd())
# from perceiver.perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality
from private_test_scripts.perceivers.crossattnperceiver import MultiModalityPerceiver, InputModality, MultiModalityPerceiverFaster
import argparse
from utils.tools import set_seed, create_dir, count_param
from custom_model.common_models import Reshape
from custom_model.iclr24_model import MultiModalityMoETransformer, MultiModalityConfig, MultiModalitySequenceMoETransformer, GATES, SM3TaskHeads

from datasets.robotics.get_data import get_data as robotics_get_data
from examples.robotics.robotics_utils import set_seeds as robotics_set_seeds
import yaml
from datasets.gentle_push.data_loader import PushTask
import fannypack
from utils.tools import dataloader_info

# Command-line interface, declared as (flag, add_argument kwargs) pairs and
# registered in this exact order. Several integer options (e.g. --debug,
# --is-train, --grad-clip) are compared against 1 later in the script, so they
# behave as on/off switches: pass 1 to enable, 0 to disable.
_CLI_OPTIONS = (
    # Run basics.
    ('--seed', dict(type=int, default=123)),
    ('--epochs', dict(type=int, default=100)),
    ('--lr', dict(type=float, default=0.0005)),
    # ('--batch-size', dict(type=int, default=128)),  # option currently disabled
    # Filesystem locations.
    ('--gentle-push-path', dict(type=str, default='datasets/gentle_push/cache')),
    ('--vt-yml-path', dict(type=str, default='private_test_scripts/perceivers/robotics_training_default.yaml')),
    ('--model-path', dict(type=str, default='private_test_scripts/model/robotics.pth')),
    ('--img-path', dict(type=str, default='log/img')),
    # Model configuration switches forwarded to MultiModalityConfig.
    ('--unlimited-capacity-on-mlp', dict(type=int, default=1)),
    ('--co-input', dict(type=int, default=1)),
    ('--seperate-qkv', dict(type=int, default=1)),
    # Dataset-related options.
    ('--no_vision', dict(action='store_true')),
    ('--no_proprioception', dict(action='store_true')),
    ('--no_haptics', dict(action='store_true')),
    ('--image_blackout_ratio', dict(type=float, default=0.0)),
    ('--sequential_image_rate', dict(type=int, default=1)),
    ('--kloss_dataset', dict(action='store_true')),
    ('--is-train', dict(type=int, default=1)),
    # Gate implementation; must be one of the keys of GATES.
    ('--gate-type', dict(type=str, default='NoisyGate', choices=GATES.keys())),
    # Gating / attention / optimisation knobs.
    ('--modality-gating-merge', dict(type=int, default=0)),
    ('--training-weight', dict(type=float, default=[0.9, 1.1, 1.5], nargs='+')),
    ('--dynamic-reweight', dict(type=int, default=0)),
    ('--cross-modality-attn', dict(type=int, default=0)),
    ('--cross-depth', dict(type=int, default=1)),
    ('--grad-clip', dict(type=int, default=0)),
    ('--drop-rate', dict(type=float, default=0.0)),
    ('--tune-gate-weight', dict(type=int, default=0)),
    ('--moe-gate-weight', dict(type=float, default=0.1)),
    # Debug mode is ON by default.
    ('--debug', dict(type=int, default=1)),
    ('--task-gating-merge', dict(type=int, default=0)),
    ('--modality-joint', dict(type=int, default=0)),
)

parser = argparse.ArgumentParser()
for _flag, _flag_kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_flag_kwargs)

args = parser.parse_args()
print(args)

# --- Runtime setup and first dataset (gentle push) ---------------------------
# Integer CLI switches are turned into booleans by comparing against 1.
tune_gate_weight = args.tune_gate_weight == 1

# Restrict PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)
# Seed via the project helper, then make sure the image output directory exists.
set_seed(args.seed)
create_dir(args.img_path)
is_debug = args.debug == 1

# Default to the CPU; when CUDA is available use the index of the current
# CUDA device (an int, not a torch.device).
device='cpu'
if torch.cuda.is_available():
    device = torch.cuda.current_device()

# Share tensors between worker processes through the file system.
torch.multiprocessing.set_sharing_strategy('file_system')
# Disabled alternative data source (MIMIC), kept for reference:
# from datasets.mimic.get_data import get_dataloader
# trains1,valid1,test1=get_dataloader(7,imputed_path='/home/pliang/yiwei/im.pk',no_robust=True,batch_size=20,fracs=1)
# Point fannypack's cache at the gentle-push data directory before loading.
fannypack.data.set_cache_path(args.gentle_push_path)
# NOTE(review): the leading 16 is presumably the sequence length — confirm
# against PushTask.get_dataloader.
trains1, valid1, test1 = PushTask.get_dataloader(16, batch_size=18, drop_last=True, test_multimodal_only=True, test_noises=[0], debug=is_debug, num_workers=8, cut_into=False)
# Keep only the first entry stored under the 'multimodal' key as the test set.
test1 = test1['multimodal'][0]


# --- Second dataset (configured through the --vt-yml-path YAML file) ---------
# NOTE(review): yaml.load with FullLoader is applied to a path taken from the
# command line; prefer yaml.safe_load if that file could ever be untrusted.
with open(args.vt_yml_path) as f:
    configs = yaml.load(f, Loader = yaml.FullLoader)
# Re-seed with the dataset's own helper before building its data.
robotics_set_seeds(args.seed, True)
trains2, valid2 = robotics_get_data(device, configs, '', debug=is_debug)
# Only train/validation data is returned above, so an independent deep copy of
# the validation data stands in as the test set.
test2 = copy.deepcopy(valid2) # No test data for this dataset


# Modality descriptors (declared the same way as for a regular perceiver).
def _make_modality(name, input_channels, input_axis):
    """Return an InputModality using the settings shared by this script.

    Every modality declared below uses num_freq_bands=6 and max_freq=1; only
    the name, channel count and number of input axes vary.
    """
    return InputModality(
        name=name,
        input_channels=input_channels,
        input_axis=input_axis,
        num_freq_bands=6,
        max_freq=1,
    )


static_modality = _make_modality('static', input_channels=1, input_axis=1)
timeseries_modality = _make_modality('timeseries', input_channels=1, input_axis=2)
colorless_image_modality = _make_modality('colorlessimage', input_channels=1, input_axis=2)
audio_spec_modality = _make_modality('audiospec', input_channels=1, input_axis=2)
timeseries_gripper_pos_modality = _make_modality('timeseries_gripper_pos', input_channels=3, input_axis=1)
timeseries_gripper_sensors_modality = _make_modality('timeseries_gripper_sensors', input_channels=7, input_axis=1)
timeseries_control_modality = _make_modality('timeseries_control', input_channels=7, input_axis=1)
colorless_image_timeseries_modality = _make_modality('colorlessimage_timeseries', input_channels=1, input_axis=3)
image_modality = _make_modality('image', input_channels=3, input_axis=2)
force_modality = _make_modality('force', input_channels=32, input_axis=1)
proprio_modality = _make_modality('proprio', input_channels=8, input_axis=1)

depth_modality = _make_modality('depth', input_channels=1, input_axis=2)
action_modality = _make_modality('action', input_channels=4, input_axis=1)

# Mixture-of-experts configuration shared by the model and the training loop.
# Integer CLI switches are converted to booleans by comparing against 1.
mconfig = MultiModalityConfig(
    # Reproducibility and output location.
    seed=args.seed,
    img_path=args.img_path,
    # Task / modality layout.
    num_tasks=2,
    modalities_name=[
        'timeseries_gripper_pos',
        'timeseries_gripper_sensors',
        'colorlessimage_timeseries',
        'timeseries_control',
        'image',
        'force',
        'proprio',
        'depth',
        'action',
    ],
    # Expert capacity settings.
    capacity_per_expert=86,
    base_capacity=32,
    capacity_ratio=1.0,
    unlimited_capacity_on_mlp=args.unlimited_capacity_on_mlp == 1,
    # Gate implementation, looked up by its CLI name.
    gate=GATES[args.gate_type],
    # Attention layout.
    seperate_qkv=args.seperate_qkv == 1,
    co_input=args.co_input == 1,
    attn_modality_specific=True,
    cross_modality_attn=args.cross_modality_attn == 1,
    # Gating-merge / reweighting switches.
    modality_gating_merge=args.modality_gating_merge == 1,
    task_gating_merge=args.task_gating_merge == 1,
    modality_joint=args.modality_joint == 1,
    dynamic_reweight=args.dynamic_reweight == 1,
)

# Optional remapping of modality names, only installed when modality gating
# merge is enabled.
if mconfig.modality_gating_merge:
    _gated_modalities = (
        'timeseries_gripper_pos',
        'timeseries_gripper_sensors',
        'colorlessimage_timeseries',
        'timeseries_control',
        'image',
        'force',
        'proprio',
        'depth',
        'action',
    )
    if mconfig.modality_joint:
        # Joint mode: every modality name is remapped to 'image'.
        _modality_remap = dict.fromkeys(_gated_modalities, 'image')
    else:
        # Otherwise each name maps to itself, except for the three overrides
        # below (updating existing keys keeps the original key order).
        _modality_remap = {name: name for name in _gated_modalities}
        _modality_remap.update(
            timeseries_gripper_pos='gripper_pos',
            timeseries_gripper_sensors='gripper_sensors',
            colorlessimage_timeseries='image',
        )
    mconfig.setting_modality_remap(_modality_remap)

# Optional remapping of task ids: both '0' and '1' are remapped to '0'.
if mconfig.task_gating_merge:
    mconfig.setting_task_remap(dict.fromkeys(('0', '1'), '0'))

# All declared modality descriptors are registered with the model, including
# the ones that neither task in this script feeds.
_all_modalities = (
    static_modality,
    timeseries_modality,
    colorless_image_modality,
    audio_spec_modality,
    timeseries_gripper_pos_modality,
    timeseries_gripper_sensors_modality,
    timeseries_control_modality,
    colorless_image_timeseries_modality,
    image_modality,
    force_modality,
    proprio_modality,
    depth_modality,
    action_modality,
)

model = MultiModalitySequenceMoETransformer(
    modalities=_all_modalities,
    args=mconfig,
    # Output number of classes.
    num_classes=2,
    # NOTE(review): presumably the number of modalities per task (4 and 5) — confirm.
    modalities_num=[4, 5],
    # Depth of the net; combined with num_latent_blocks_per_layer to produce
    # the full Perceiver.
    depth=1,
    # This parameter is 1 in the original Lucidrain implementation.
    num_latent_blocks_per_layer=1,
    # Whether to weight-tie layers (optional, as indicated in the diagram).
    weight_tie_layers=True,
    # Number of latents (also called induced set points or centroids,
    # depending on the paper) and their dimension.
    num_latents=20,
    latent_dim=64,
    # Latent self-attention: number of heads and per-head dimension.
    latent_heads=8,
    latent_dim_head=64,
    # Cross-attention: number of heads (the paper uses 1), per-head dimension
    # and depth.
    cross_heads=1,
    cross_dim_head=64,
    cross_depth=args.cross_depth,
    # A single CLI drop rate drives both dropout settings.
    attn_dropout=args.drop_rate,
    ff_dropout=args.drop_rate,
).to(device)

# model.to_logitslist=torch.nn.ModuleList([torch.nn.Sequential(torch.nn.LayerNorm(64 * 4),torch.nn.Linear(64 * 4, 16 * 2),Reshape([-1, 16, 2])).to(device), 
#                                          torch.nn.Sequential(torch.nn.LayerNorm(64 * 5),torch.nn.Linear(64 * 5,2)).to(device)])

# Replace the model's output heads with one SM3TaskHeads module per task,
# each built from (task name, modality names for that task).
_task_head_specs = (
    ('push', ['timeseries_gripper_pos', 'timeseries_gripper_sensors', 'colorlessimage_timeseries', 'timeseries_control']),
    ('vt', ['image', 'force', 'proprio', 'depth', 'action']),
)
model.to_logitslist = torch.nn.ModuleList(
    [
        SM3TaskHeads(task_name, modality_names, 16, device).to(device)
        for task_name, modality_names in _task_head_specs
    ]
)

# from private_test_scripts.perceivers.train_structure_multitask import train
from private_test_scripts.mmoe_transformer.multitask_training import train



def encoder_fn(x):
    """Insert a singleton axis at position 1 for 'proprio' and 'action'.

    The mapping ``x`` is modified in place and the same object is returned.
    Keys that are absent are skipped; all other entries are left untouched.
    """
    for key in ('proprio', 'action'):
        if key in x:
            x[key] = x[key].unsqueeze(1)
    return x

# Both tasks share the same batch pre-processing hook.
encoder = [encoder_fn] * 2

# print(count_param(model))

# Launch multitask training. Every per-task list is ordered
# [gentle-push task, vt task]: the first is trained as regression (MSELoss),
# the second as classification (CrossEntropyLoss).
train(
    model,
    args.epochs,
    [trains1, trains2],
    [valid1, valid2],
    [test1, test2],
    [
        ['timeseries_gripper_pos', 'timeseries_gripper_sensors', 'colorlessimage_timeseries', 'timeseries_control'],
        ['image', 'force', 'proprio', 'depth', 'action'],
    ],
    args.model_path,
    lr=args.lr,
    device=device,
    train_weights=args.training_weight,
    encoder=encoder,
    eval_weights=[100.0, 1.0],
    is_classification=[False, True],
    criterions=[torch.nn.MSELoss(), torch.nn.CrossEntropyLoss()],
    is_affect=[False, False],
    unsqueezing=[False, False],
    transpose=[False, False],
    weight_decay=0.0,
    is_train=args.is_train == 1,
    args=args,
    mconfig=mconfig,
    calc_flops=False,
    grad_clip=args.grad_clip == 1,
    tune_gate_weight=tune_gate_weight,
    # The same MoE gate weight is used for both tasks.
    moe_gate_weight=[args.moe_gate_weight] * 2,
)
