from email.policy import default
import sys
import os
from telnetlib import GA
import torch
sys.path.append(os.getcwd())
from private_test_scripts.perceivers.crossattnperceiver import MultiModalityPerceiver, InputModality, MultiModalityPerceiverFaster
import argparse
from utils.tools import set_seed, create_dir
from custom_model.common_models import Reshape
from custom_model.iclr24_model import SM3TaskHeads, MultiModalityConfig, MultiModalitySequenceMoETransformer, NoisyGate, NoisyVMoEGate, GATES
# Learning-rate scheduler classes selectable by name via --lr-schedular;
# the 'None' entry maps to None (no scheduler class).
SCHEDULERS = {
    'None': None,
    'CosineAnnealingLR': torch.optim.lr_scheduler.CosineAnnealingLR,
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}

# Command-line interface for the three-task (enrico, avmnist, gentle_push) run.
# Many options are declared type=int with 0/1 defaults and are used as booleans:
# the script compares them to 1 (e.g. args.debug == 1) rather than using store_true.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.0005)
# parser.add_argument('--batch-size', type=int, default=128)
# Dataset locations. enrico and avmnist batch sizes are hard-coded (32) at the loader
# calls; only gentle_push has a configurable batch size.
parser.add_argument('--enrico-path', type=str, default='datasets/data/enrico')
parser.add_argument('--avmnist-path', type=str, default='datasets/data/avmnist')
parser.add_argument('--gentle-push-path', type=str, default='datasets/gentle_push/cache')
parser.add_argument('--gentle-push-batch-size', type=int, default=32)
# Path handed to train() for both the pretrain-only and the multitask stage.
parser.add_argument('--model-path', type=str, default='private_test_scripts/model/medium.pth')
# Directory created at startup and passed to MultiModalityConfig as img_path.
parser.add_argument('--img-path', type=str, default='log/img_three_task')
parser.add_argument('--unlimited-capacity-on-mlp', type=int, default=1)
parser.add_argument('--co-input', type=int, default=1)
parser.add_argument('--seperate-qkv', type=int, default=1)

# The underscore-named options below are not read directly anywhere in this script;
# presumably they are consumed by PushTask.get_dataset_args(args) — TODO confirm.
parser.add_argument("--no_vision", action="store_true")
parser.add_argument("--no_proprioception", action="store_true")
parser.add_argument("--no_haptics", action="store_true")
parser.add_argument("--image_blackout_ratio", type=float, default=0.0)
parser.add_argument("--sequential_image_rate", type=int, default=1)
parser.add_argument("--kloss_dataset", action="store_true")
parser.add_argument('--is-train', type=int, default=1)
# Must be a key of GATES; the selected gate class is passed to MultiModalityConfig.
parser.add_argument('--gate-type', type=str, default='NoisyGate', choices=GATES.keys())

parser.add_argument('--modality-gating-merge', type=int, default=0)
# Passed to train() as train_weights in the multitask stage; the default has three
# entries, presumably one per task in (enrico, avmnist, gentle_push) order — confirm.
parser.add_argument('--training-weight', type=float, default=[0.9, 1.1, 1.5], nargs='+')
parser.add_argument('--dynamic-reweight', type=int, default=0)
parser.add_argument('--cross-modality-attn', type=int, default=0)
parser.add_argument('--cross-depth', type=int, default=1)
parser.add_argument('--grad-clip', type=int, default=0)
parser.add_argument('--capacity-ratio', type=float, default=1.)
parser.add_argument('--capacity-ratios', type=float, default=[1, 1, 1], nargs='+')
# When 1, gentle_push loaders are built with cut_into=True and the trajectory
# modality is declared with 16 channels over 2 axes instead of 1 channel over 3.
parser.add_argument('--push-cut-into', type=int, default=0)
parser.add_argument('--tune-gate-weight', type=int, default=0)
parser.add_argument('--mlp-top-k', type=int, default=2)
parser.add_argument('--attn-top-k', type=int, default=2)
parser.add_argument('--moe-gate-weight', type=float, default=0.1)
parser.add_argument('--gradient-blending', type=int, default=0)
parser.add_argument('--gradient-blending-epoch', type=int, default=10)
# Defaults to 1 (on); forwarded as debug= to the avmnist and gentle_push loaders.
parser.add_argument('--debug', type=int, default=1)
parser.add_argument('--cross-attn-use-moe', action='store_true')
parser.add_argument('--num-experts', type=int, default=16)
parser.add_argument('--weight-decay', type=float, default=0.0)
# Overridden to 12 when --use-individual-latent-dim is 1.
parser.add_argument('--num-latent', type=int, default=12)
parser.add_argument('--use-individual-latent-dim', type=int, default=0)
parser.add_argument('--outter-task-loss', type=int, default=0)
parser.add_argument('--grad-clip-value', type=float, default=1.)
# Key into SCHEDULERS (spelling "schedular" is part of the CLI and kept as-is).
parser.add_argument('--lr-schedular', default='None', choices=SCHEDULERS.keys(), type=str)
parser.add_argument('--push-seq-length', type=int, default=16)
# When 1, only the gentle_push pretraining stage runs and the script then exits.
parser.add_argument('--only-pretrain', default=0, type=int)
parser.add_argument('--push-without-valid', default=0, type=int)
# State dict loaded into the model before the multitask stage (when --only-pretrain is 0).
parser.add_argument('--pretrain-model-path', type=str, default='private_test_scripts/model/three_medium_tasks_pretrain0.8.pth')

args = parser.parse_args()
print(args)
set_seed(args.seed)
create_dir(args.img_path)

# Convert the integer 0/1 CLI switches to booleans once, up front.
is_debug = args.debug == 1
tune_gate_weight = args.tune_gate_weight == 1
# NOTE(review): is_gradient_blending is not read anywhere in this script — confirm it is still needed.
is_gradient_blending = args.gradient_blending == 1
only_pretrain = args.only_pretrain == 1
push_without_valid = args.push_without_valid == 1
# Share tensors between processes (e.g. DataLoader workers) via the 'file_system'
# strategy instead of PyTorch's default file-descriptor strategy on Linux.
torch.multiprocessing.set_sharing_strategy('file_system')

# enrico
from datasets.enrico.get_data import get_dataloader
dls, weights = get_dataloader(args.enrico_path, batch_size=32, num_workers=4, cut_into=True, img_noise=False, wireframe_noise=False, hmmt=False)
trains1, valid1, test1 = dls
# The test split is indexed by key; only the first loader under 'image' is used.
test1 = test1['image'][0]
# Show each field of the first training sample: its shape when it has one, else the
# value itself. Fetch the sample once; every dataset[0] re-runs __getitem__.
enrico_sample = trains1.dataset[0]
for field in enrico_sample:
    print(field.shape if hasattr(field, 'shape') else field)
print("loaded enrico dataset successfully!!!")
# print(test1)
# exit()
# avmnist
from datasets.avmnist.get_data import get_dataloader
trains2, valid2, test2 = get_dataloader(args.avmnist_path, flatten_audio=True, unsqueeze_channel=0, batch_size=32, cut_into=True, num_workers=8, debug=is_debug)
print("loaded avmnist dataset successfully!!!")
# Show each field of the first training sample: its shape when it has one, else the
# field value (previously the whole sample was printed for shapeless fields).
# Fetch the sample once; every dataset[0] re-runs __getitem__.
avmnist_sample = trains2.dataset[0]
for field in avmnist_sample:
    print(field.shape if hasattr(field, 'shape') else field)

# exit()
from datasets.gentle_push.data_loader import PushTask
import argparse
import fannypack

Task = PushTask

# The returned value is not used below. NOTE(review): the call is kept in case
# get_dataset_args validates or reads the underscore-named CLI options — confirm.
dataset_args = Task.get_dataset_args(args)

fannypack.data.set_cache_path(args.gentle_push_path)


push_cut_into = args.push_cut_into == 1

trains3, valid3, test3 = Task.get_dataloader(args.push_seq_length, batch_size=args.gentle_push_batch_size,
                                             drop_last=True, test_noises=[0],
                                             test_multimodal_only=True,
                                             num_workers=8, debug=is_debug,
                                             cut_into=push_cut_into, pred_last=False,
                                             without_valid=push_without_valid)
# The test split is indexed by key; only the first loader under 'multimodal' is used.
test3 = test3['multimodal'][0]
print("loaded gentle_push dataset successfully!!!")

# Show each field of the first training sample: its shape when it has one, else the
# value itself. Fetch the sample once; every dataset[0] re-runs __getitem__.
push_sample = trains3.dataset[0]
for field in push_sample:
    print(field.shape if hasattr(field, 'shape') else field)

# PushTask.get_dataloader(16, batch_size=18, drop_last=True, test_multimodal_only=True, test_noises=[0])

# Restrict PyTorch to one intra-op CPU thread.
torch.set_num_threads(1)
# Prefer the current CUDA device (an integer index) when CUDA is available, else CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

# gentle_push modality
# [B, 16, 3, 1]
pose_modality = InputModality(
    name='pose',
    input_channels=3,  # number of channels for each token of the input
    input_axis=1,  # number of input axes (2 would be used for images)
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=1.,  # maximum frequency, hyperparameter depending on how fine the data is
)
print('pose', pose_modality.input_dim)
# [B, 16, 7, 1]
sensor_modality = InputModality(
    name='sensor',
    input_channels=7,  # number of channels for each token of the input
    input_axis=1,  # number of input axes (2 would be used for images)
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=1.,  # maximum frequency, hyperparameter depending on how fine the data is
)
print('sensor', sensor_modality.input_dim)
# [B, 16, 32, 32, 1]
# With --push-cut-into the trajectory input is declared as 16 channels over 2 axes
# (matching cut_into=push_cut_into on the gentle_push loaders); otherwise as a
# single channel over 3 axes.
if push_cut_into:
    trajectory_modality = InputModality(
        name='trajectory',
        input_channels=16,  # number of channels for each token of the input
        input_axis=2,  # number of input axes
        num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
        max_freq=1.,  # maximum frequency, hyperparameter depending on how fine the data is
    )
else:
    trajectory_modality = InputModality(
        name='trajectory',
        input_channels=1,  # number of channels for each token of the input
        input_axis=3,  # number of input axes
        num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
        max_freq=1.,  # maximum frequency, hyperparameter depending on how fine the data is
    )

print('trajectory', trajectory_modality.input_dim)
# [B, 16, 7, 1]

control_modality = InputModality(
    name='control',
    input_channels=7,  # number of channels for each token of the input
    input_axis=1,  # number of input axes (2 would be used for images)
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=1.,  # maximum frequency, hyperparameter depending on how fine the data is
)


print('control', control_modality.input_dim)
# enrico modality
# Enrico inputs: 'image_1' and 'image_2', each 384 channels per token over 2 axes.
# All modalities here share the same Fourier-feature settings
# (num_freq_bands=6, max_freq=1.).
image_1_modality = InputModality(name='image_1', input_channels=384, input_axis=2,
                                 num_freq_bands=6, max_freq=1.)
print('image_1', image_1_modality.input_dim)

image_2_modality = InputModality(name='image_2', input_channels=384, input_axis=2,
                                 num_freq_bands=6, max_freq=1.)
print('image_2', image_2_modality.input_dim)

# AV-MNIST inputs: 'image' with 16 channels per token and 'audio' with 256 channels
# per token, both over 2 axes.
image_modality = InputModality(name='image', input_channels=16, input_axis=2,
                               num_freq_bands=6, max_freq=1.)
print('image', image_modality.input_dim)

audio_modality = InputModality(name='audio', input_channels=256, input_axis=2,
                               num_freq_bands=6, max_freq=1.)
print('audio', audio_modality.input_dim)

# Model/routing configuration shared by the transformer and train().
# Keyword arguments are grouped by name only; see MultiModalityConfig for semantics.
# The integer 0/1 CLI switches are converted to booleans here.
mconfig = MultiModalityConfig(
    # run bookkeeping
    seed=args.seed,
    img_path=args.img_path,
    # task / modality layout; names follow the order of the model's modalities tuple
    num_tasks=3,
    modalities_name=['image_1', 'image_2', 'image', 'audio', 'pose', 'sensor', 'trajectory', 'control'],
    co_input=args.co_input == 1,
    # attention-related options
    seperate_qkv=args.seperate_qkv == 1,
    attn_modality_specific=True,
    cross_modality_attn=args.cross_modality_attn == 1,
    cross_attn_use_moe=args.cross_attn_use_moe,
    attn_top_k=args.attn_top_k,
    # gate / expert / capacity options
    gate=GATES[args.gate_type],
    num_experts=args.num_experts,
    mlp_top_k=args.mlp_top_k,
    capacity_per_expert=86,
    base_capacity=32,
    capacity_ratio=args.capacity_ratio,
    capacity_ratios=args.capacity_ratios,
    unlimited_capacity_on_mlp=args.unlimited_capacity_on_mlp == 1,
    modality_gating_merge=args.modality_gating_merge == 1,
    dynamic_reweight=args.dynamic_reweight == 1,
    # latent / loss / clipping options
    use_individual_latent_dim=args.use_individual_latent_dim == 1,
    outter_task_loss=args.outter_task_loss == 1,
    grad_clip_value=args.grad_clip_value,
)
# print(mconfig.unlimited_capacity_on_mlp)
if mconfig.modality_gating_merge:
    # Modality -> remap target. Every modality maps to itself except image_1
    # (-> 'image') and image_2 (-> 'set').
    mconfig.setting_modality_remap({
        'image_1': 'image', 'image_2': 'set',
        'image': 'image', 'audio': 'audio',
        'pose': 'pose', 'sensor': 'sensor',
        'trajectory': 'trajectory', 'control': 'control',
    })

if mconfig.task_gating_merge:
    # Task ids '0', '1', '2' all map to task '0'.
    mconfig.setting_task_remap({str(task_id): '0' for task_id in range(3)})

if mconfig.use_individual_latent_dim:
    # An individual_latent_dim of 12 for each task index 0..2; the global
    # --num-latent value is forced to the same 12.
    mconfig.individual_latent_dim = dict.fromkeys(range(3), 12)
    args.num_latent = 12

model = MultiModalitySequenceMoETransformer(
    # Same order as mconfig's modalities_name.
    modalities=(
        image_1_modality, image_2_modality,  # enrico
        image_modality, audio_modality,  # avmnist
        pose_modality, sensor_modality, trajectory_modality, control_modality,  # gentle_push
    ),
    # Appears to be the number of consecutive modalities above used by each task
    # (2 enrico, 2 avmnist, 4 gentle_push) — matches the per-task lists given to train().
    modalities_num=[2, 2, 4],
    depth=1,  # depth of net, combined with num_latent_blocks_per_layer
    num_latent_blocks_per_layer=1,  # this parameter is 1 in the original Lucidrain implementation
    weight_tie_layers=True,  # whether to weight tie layers
    # number of latents, or induced set points, or centroids (different papers, different names)
    num_latents=args.num_latent,
    latent_dim=64,
    latent_heads=8,
    latent_dim_head=64,
    cross_heads=1,  # number of heads for cross attention; the paper uses 1
    cross_dim_head=64,
    cross_depth=args.cross_depth,
    num_classes=2,
    attn_dropout=0.,
    ff_dropout=0.,
    args=mconfig,
).to(device)

from private_test_scripts.mmoe_transformer.multitask_training import train
# Both stages start from a model whose only task head is the gentle_push ('push') head.
model.to_logitslist = torch.nn.ModuleList([
    SM3TaskHeads('push', ['pose', 'sensor', 'trajectory', 'control'], args.push_seq_length, device).to(device)])

if only_pretrain:
    # Stage 1: train on gentle_push alone, then stop. args.model_path is handed to
    # train() (presumably where the checkpoint is written — confirm).
    model.modalities_num = [4]
    mconfig.capacity_ratios = None
    train(model, args.epochs, [trains3], [valid3], [test3],
          [['pose', 'sensor', 'trajectory', 'control']],
          args.model_path, lr=args.lr, device=device, train_weights=[1.],
          eval_weights=[1.],
          is_affect=[False, False, False], unsqueezing=[False, False, False], transpose=[False, False, False],
          is_train=args.is_train == 1,
          args=args, mconfig=mconfig,
          calc_flops=False,
          grad_clip=args.grad_clip == 1,
          tune_gate_weight=tune_gate_weight,
          moe_gate_weight=[args.moe_gate_weight],
          weight_decay=args.weight_decay,
          # NOTE(review): --lr-schedular is not applied in this stage — confirm intended.
          schedular=None,
          criterions=[torch.nn.MSELoss()],
          is_classification=[False])
    sys.exit()

# Stage 2 setup: restore the pretrained single-task weights.
model.to_logits = model.to_logitslist[0]
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host;
# load_state_dict then copies the tensors onto the model's own device.
model.load_state_dict(torch.load(args.pretrain_model_path, map_location='cpu'))

model.args.capacity_ratios = args.capacity_ratios
# model.load_state_dict(torch.load(args.model_path))
# Multitask stage: keep the pretrained push head, add heads for enrico and
# avmnist, and train all three tasks jointly.
pre_head = model.to_logitslist[0]
task_heads = [
    SM3TaskHeads('enrico', ['image_1', 'image_2'], args.push_seq_length, device).to(device),
    SM3TaskHeads('av_mnist', ['image', 'audio'], args.push_seq_length, device).to(device),
    SM3TaskHeads('push', ['pose', 'sensor', 'trajectory', 'control'], args.push_seq_length, device).to(device),
]
model.to_logitslist = torch.nn.ModuleList(task_heads)
model.modalities_num = [2, 2, 4]
# The freshly built push head at index 2 is replaced by the pretrained one.
model.to_logitslist[2] = pre_head

# Per-task evaluation weights (enrico, avmnist, gentle_push); the push weight is
# raised further when gentle_push runs without a validation split.
eval_weights = [0.5, 1.0, 1000] if push_without_valid else [1.0, 1.0, 10]

task_modalities = [
    ['image_1', 'image_2'],
    ['image', 'audio'],
    ['pose', 'sensor', 'trajectory', 'control'],
]
train(model, args.epochs,
      [trains1, trains2, trains3],
      [valid1, valid2, valid3],
      [test1, test2, test3],
      task_modalities,
      args.model_path,
      lr=args.lr,
      device=device,
      train_weights=args.training_weight,
      eval_weights=eval_weights,
      is_affect=[False] * 3,
      unsqueezing=[False] * 3,
      transpose=[False] * 3,
      is_train=args.is_train == 1,
      args=args,
      mconfig=mconfig,
      calc_flops=False,
      grad_clip=args.grad_clip == 1,
      tune_gate_weight=tune_gate_weight,
      moe_gate_weight=[args.moe_gate_weight] * 3,
      weight_decay=args.weight_decay,
      schedular=SCHEDULERS[args.lr_schedular])
