import warnings
warnings.filterwarnings("ignore")

import torch
import sys
import os
sys.path.append(os.getcwd())
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
import argparse
import random
import numpy as np
from easydict import EasyDict as edict

from model.common_models import MLP, Transformer  # noqa
from training_structures.unimodal import train, test # noqa
from datasets.affect.data_loader import MMDataLoader
from training_structures.utils.functions import ConfigParser

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all CUDA devices (the current one included); it is a no-op when
    # CUDA is not available, so this is safe on CPU-only machines.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can select different kernels from run to run, so it
    # has to be switched off as well for results to be repeatable.
    torch.backends.cudnn.benchmark = False

# Command-line interface.
#
# Parsing happens at module level, so ``args`` is a module global that the
# ``__main__`` guard at the bottom of the file consumes. Inside ``main`` the
# hyper-parameters read from the config are merged on top of these values, so
# any key that also appears in the config overrides the corresponding flag.
parser = argparse.ArgumentParser(
    description=(
        "Train and evaluate a single-modality Transformer encoder "
        "with an MLP head."
    ),
)

# --- Config / dataset selection -------------------------------------------
parser.add_argument(
    '--config_path', type=str, default='config_complete.json',
    help="Path of the config file handed to ConfigParser.",
)
parser.add_argument(
    '--dataset_name', type=str, default='mosi',
    help="Dataset key used to look up hyper-parameters in the config.",
)
parser.add_argument(
    '--model_name', type=str, default='unimodaltransformer',
    help="Model key used to look up hyper-parameters in the config.",
)
# Only 0, 1 and 2 map to a modality in main(); any other value would leave
# every use_* flag False, so reject it at parse time.
parser.add_argument(
    '--modality_num', type=int, default=0, choices=[0, 1, 2],
    help="Index of the single modality to use: 0=video, 1=audio, 2=text.",
)

# --- Data loading -----------------------------------------------------------
parser.add_argument(
    '--train_ratio', type=float, default=1.0,
    help="Training ratio forwarded to MMDataLoader.",
)
# NOTE(review): batch_size and num_workers are not referenced directly in this
# file; presumably MMDataLoader reads them from the argument object - confirm.
parser.add_argument(
    '--batch_size', type=int, default=32,
    help="Batch size.",
)
parser.add_argument(
    '--num_workers', type=int, default=4,
    help="Number of data-loading workers.",
)

# --- Model ------------------------------------------------------------------
parser.add_argument(
    '--input_dim', type=int, default=768, choices=[20, 5, 768],
    help="Input feature dimension given to the Transformer encoder.",
)
parser.add_argument(
    '--encoder_hidden_dim', type=int, default=512,
    help="Encoder width; also the input size of the MLP head.",
)
parser.add_argument(
    '--head_hidden_dim', type=int, default=256,
    help="Hidden size of the MLP head.",
)
parser.add_argument(
    '--train_mode', type=str, default='regression',
    choices=['regression', 'classification'],
    help="Task type; regression uses a single output unit.",
)

# --- Optimisation -----------------------------------------------------------
parser.add_argument(
    '--epoch', type=int, default=100,
    help="Epoch count passed to train().",
)
parser.add_argument(
    '--lr', type=float, default=1e-4,
    help="Learning rate for the AdamW optimizer.",
)
parser.add_argument(
    '--early_stop', action='store_true',
    help="Enable early stopping in train().",
)
parser.add_argument(
    '--patience', type=int, default=10,
    help="Early-stopping patience passed to train().",
)
parser.add_argument(
    '--weight_decay', type=float, default=0.005,
    help="Weight decay for the AdamW optimizer.",
)
parser.add_argument(
    '--clip_value', type=float, default=1.0,
    help="Clip value passed to train().",
)

# --- Runtime / output -------------------------------------------------------
parser.add_argument(
    '--save_dir', type=str, default='results/saved_models',
    help="Root directory for saved models and the results summary file.",
)
parser.add_argument(
    '--device', type=str, default='cuda',
    help="Torch device string used for the models, training and testing.",
)
parser.add_argument(
    '--seed', type=int, default=42,
    help="Random seed; also part of the experiment name.",
)
args = parser.parse_args()


def _load_saved_module(path, device):
    """Load a whole pickled ``torch.nn.Module`` from *path* onto *device*.

    Since PyTorch 2.6 ``torch.load`` defaults to ``weights_only=True``, which
    refuses to unpickle full module objects, so the flag is disabled
    explicitly. ``map_location`` makes the tensors land on the requested
    device regardless of where they lived when the file was written.

    NOTE: ``weights_only=False`` unpickles arbitrary objects. That is
    acceptable here only because the files are produced by this same run;
    do not point this helper at untrusted checkpoints.
    """
    try:
        module = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the ``weights_only`` keyword.
        module = torch.load(path, map_location=device)
    return module.to(device)


def main(args):
    """Train a unimodal encoder + head, test the saved models, log the result.

    The CLI namespace is merged with the hyper-parameters that the config
    provides for ``(dataset_name, model_name)``; config values win on key
    clashes. ``num_heads`` and ``nlevels`` (and ``num_classes`` in
    classification mode) are not CLI flags, so they must be supplied by the
    config.

    Args:
        args: Parsed ``argparse.Namespace`` from the command line.

    Side effects:
        Creates ``<save_dir>/unimodal/<exp_name>/`` and appends one line to
        ``<save_dir>/unimodal/all_unimodal_experiments_summary.txt``.
    """
    config = ConfigParser(args.config_path)
    hyparams = config.get_params(args.dataset_name, args.model_name)
    # Work on an attribute-style dict instead of rebinding the parameter.
    cfg = edict(vars(args))
    cfg.update(hyparams)

    # Exactly one flag is True: index 0 -> video, 1 -> audio, 2 -> text.
    use_video, use_audio, use_text = [i == cfg.modality_num for i in range(3)]
    cfg.use_bert = bool(use_text)
    print(f"Arguments: {cfg}")

    dataloader = MMDataLoader(
        cfg,
        train_ratio=cfg.train_ratio,
        use_video=use_video,
        use_audio=use_audio,
        use_text=use_text,
    )

    encoder = Transformer(
        cfg,
        n_features=cfg.input_dim,
        dim=cfg.encoder_hidden_dim,
        nhead=cfg.num_heads,
        num_layers=cfg.nlevels,
    ).to(cfg.device)
    output_dim = cfg.num_classes if cfg.train_mode == "classification" else 1
    head = MLP(cfg.encoder_hidden_dim, cfg.head_hidden_dim, output_dim).to(cfg.device)

    exp_name = f"modal{cfg.modality_num}_ratio{cfg.train_ratio}_seed{cfg.seed}"

    exp_dir = os.path.join(cfg.save_dir, 'unimodal', exp_name)
    os.makedirs(exp_dir, exist_ok=True)
    saved_encoder = os.path.join(exp_dir, 'encoder.pt')
    saved_head = os.path.join(exp_dir, 'head.pt')

    # NOTE(review): L1Loss is used for both train modes; confirm that this is
    # what train()/test() expect when train_mode is "classification".
    criterion = torch.nn.L1Loss()

    print("Starting Training...")
    train(
        encoder,
        head,
        dataloader['train'],
        dataloader['valid'],
        cfg.epoch,
        early_stop=cfg.early_stop,
        patience=cfg.patience,
        optimtype=torch.optim.AdamW,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        clip_value=cfg.clip_value,
        criterion=criterion,
        save_encoder=saved_encoder,
        save_head=saved_head,
        task=cfg.train_mode,
        device=cfg.device,
    )

    # train() is presumably responsible for writing the best encoder/head to
    # the two paths above; they are reloaded here as whole modules.
    print("Loading best models for testing...")
    best_encoder = _load_saved_module(saved_encoder, cfg.device)
    best_head = _load_saved_module(saved_head, cfg.device)

    print("Starting Testing...")
    test_result = test(
        best_encoder,
        best_head,
        dataloader['test'],
        task=cfg.train_mode,
        criterion=criterion,
        device=cfg.device,
        dataset=cfg.dataset_name,
    )

    # One line per experiment, appended so repeated runs accumulate.
    summary_file = os.path.join(cfg.save_dir, 'unimodal', 'all_unimodal_experiments_summary.txt')
    with open(summary_file, 'a') as f:
        f.write(f"{exp_name}: {test_result}\n")


if __name__ == "__main__":
    # Seed before main() so that model construction there is repeatable.
    set_seed(args.seed)
    main(args)
