import os
import torch 
import sys
import time
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DistributedSampler, RandomSampler
from utils.utils import fix_seeds, setup_cudnn, get_logger, cal_flops
from utils.metric import Metrics
from config import ex
from models.unimodal import *
from dataset.food101_dataset import FOOD101Dataset
from dataset.mmimdb_dataset import MMIMDBDataset
from dataset.ks import KineticsSound
from dataset.ave import AVE
from dataset.creamad import CREAMAD
from utils.optimizers import get_optimizer
from utils.schedulers import get_scheduler
from scripts.test_unimodal import evaluate
import wandb
import functools
from pytorch_metric_learning import losses
import gc
import pprint
import torch.nn.functional as F


def print_trainable_parameters(model: nn.Module):
    """
    Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over ``model.named_parameters()`` and prints the trainable
    element count, the total element count and the trainable percentage.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. an ``nn.Module``).

    Returns:
        A ``(trainable_params, all_params)`` tuple so callers can also log the
        numbers instead of only reading stdout.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A parameter-free model would otherwise raise ZeroDivisionError here.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"Trainable Params: {trainable_params} || All Params: {all_param} || Trainable (%): {trainable_pct:.2f}"
    )
    return trainable_params, all_param


# NOTE: sacred's ``ex.automain`` starts the experiment as soon as it decorates ``main``
# when this script is executed directly, so every helper ``main`` uses must be defined
# above the decorator.


def _build_datasets(ds, dataset_type, missing_info, _config):
    """Build the ``(trainset, valset)`` pair for dataset class ``ds``.

    Image-text datasets (``dataset_type == 1``) take a root path, transform keys and
    image/text options; audio-video datasets (``dataset_type == 2``) take only a root dir.
    Both splits receive the same ``missing_info``.
    """
    if dataset_type == 1:
        root = f"{_config['data_root']}/{_config['datasets'][0]}"
        shared = dict(
            image_size=_config['image_size'],
            max_text_len=_config['max_text_len'],
            draw_false_image=_config['draw_false_image'],
            draw_false_text=_config['draw_false_text'],
            image_only=_config['image_only'],
            missing_info=missing_info,
        )
        # NOTE(review): the test split is also built with ``train_transform_keys`` — confirm intended.
        trainset = ds(root, _config['train_transform_keys'], split="train", **shared)
        valset = ds(root, _config['train_transform_keys'], split="test", **shared)
    else:
        trainset = ds(split="train", dataset_root_dir=_config['data_dir'], missing_info=missing_info)
        valset = ds(split="test", dataset_root_dir=_config['data_dir'], missing_info=missing_info)
    return trainset, valset


def _build_model(model_name, num_classes, max_text_len, enable_lora, _config):
    """Instantiate the unimodal classifier selected by ``model_name``.

    Exits the process with a message when ``model_name`` is not supported.
    """
    # model name -> (classifier class, config key holding its LoRA target modules)
    registry = {
        'bert': (BertClassifier, 'bert_target_modules'),
        'vit': (ViTClassifier, 'vit_target_modules'),
        'ast': (ASTClassifier, 'vit_target_modules'),
        'video': (VideoClassifier, 'vit_target_modules'),
    }
    if model_name not in registry:
        sys.exit(f"Unknown model_name '{model_name}'. Aborting!")
    model_cls, target_key = registry[model_name]
    return model_cls(
        num_classes,
        max_text_len,
        r=_config['r'],
        lora_alpha=_config['lora_alpha'],
        lora_dropout=_config['lora_dropout'],
        target_modules=_config[target_key],
        enable_lora=enable_lora,
    )


def _freeze_backbone(model, model_name):
    """Disable gradients on the pretrained backbone submodule of ``model``."""
    # 'video' classifiers expose their backbone under ``.vit`` as well.
    backbone_attr = {'vit': 'vit', 'bert': 'bert', 'ast': 'ast', 'video': 'vit'}.get(model_name)
    if backbone_attr is None:
        return
    for param in getattr(model, backbone_attr).parameters():
        param.requires_grad = False


def _mean_lr(scheduler):
    """Return the learning rate averaged over the scheduler's parameter groups."""
    lrs = scheduler.get_lr()
    return sum(lrs) / len(lrs)


@ex.automain
def main(_config):
    """Fine-tune a single-modality classifier and keep the best checkpoint (sacred entry point).

    Sacred injects the experiment configuration as ``_config``. The function builds the
    train/test datasets for ``exp_name`` and the classifier for ``model_name`` (with LoRA
    when ``enable_lora``; otherwise the pretrained backbone is frozen). It then trains for
    ``max_epoch`` epochs and evaluates on the test split after every epoch. Whenever the
    test metric improves (macro-F1 for MM-IMDb, accuracy otherwise), it saves both the
    weights and a resumable checkpoint under ``<save_dir>/<wandb_exp_name>``.

    Exits the process when ``exp_name`` or ``model_name`` is not recognised.
    """
    print("Experiment Configurations:")
    print(_config)
    fix_seeds(_config["seed"])
    setup_cudnn()
    wandb_exp_name = _config["wandb_exp_name"]
    if _config['use_wandb']:
        wandb.init(project=_config['wandb_project_name'], entity="ENTITY", name=wandb_exp_name)

    save_dir = Path(_config['save_dir'], wandb_exp_name)
    os.makedirs(save_dir, exist_ok=True)
    logger = get_logger(save_dir / 'train.log')

    start = time.time()
    best_performance = 0.0
    best_epoch = 0
    num_workers = _config['num_workers']
    device = torch.device(_config['device'])
    num_classes = _config['class_num']
    max_text_len = _config['max_text_len']
    batch_size = _config['batch_size']
    num_epoch = _config['max_epoch']
    enable_lora = _config['enable_lora']
    max_steps = _config['max_steps']
    is_imdb = _config['exp_name'] == "finetune_mmimdb"
    model_name = _config['model_name']

    # construct missing modality info
    missing_info = {
        'ratio': _config["missing_ratio"],
        'type': _config["missing_type"],
        'both_ratio': _config["both_ratio"],
        'missing_table_root': _config["missing_table_root"],
        'simulate_missing': _config["simulate_missing"],
        'only_paired': _config["only_paired"],
    }

    # Dataset: exp_name -> (dataset class, dataset_type); 1 -> Image-Text, 2 -> Audio-Video
    experiments = {
        "finetune_mmimdb": (MMIMDBDataset, 1),
        "finetune_food101": (FOOD101Dataset, 1),
        "finetune_ks": (KineticsSound, 2),
        "finetune_ave": (AVE, 2),
        "finetune_creamad": (CREAMAD, 2),
    }
    if _config['exp_name'] not in experiments:
        sys.exit("No valid experiment selected. Aborting!")
    ds, dataset_type = experiments[_config['exp_name']]
    trainset, valset = _build_datasets(ds, dataset_type, missing_info, _config)

    model = _build_model(model_name, num_classes, max_text_len, enable_lora, _config)

    # Freeze pretrained models when no LoRA adapters are trained
    if not enable_lora:
        _freeze_backbone(model, model_name)

    print_trainable_parameters(model)
    print("Total Train Samples:", len(trainset))
    print("Total Test Samples:", len(valset))

    logger.info('================== model structure =====================')
    logger.info(model)
    logger.info('================== training config =====================')
    logger.info(_config)
    logger.info('================== parameter count =====================')
    logger.info(sum(p.numel() for p in model.parameters() if p.requires_grad))
    logger.info(f"Total Train Samples: {len(trainset)}")
    logger.info(f"Total Test Samples: {len(valset)}")

    model = torch.nn.DataParallel(model, device_ids=_config['gpu_ids'])
    model = model.to(device)

    # Matches the number of batches per epoch because the train loader uses drop_last=True.
    iters_per_epoch = len(trainset) // batch_size
    loss_fn = nn.BCEWithLogitsLoss() if is_imdb else nn.CrossEntropyLoss()

    optimizer = get_optimizer(model, _config['optim_type'], _config['learning_rate'], _config['weight_decay'])
    scheduler = get_scheduler(_config['scheduler'], optimizer, int((num_epoch+1)*iters_per_epoch), _config['power'], iters_per_epoch * _config['warmup'], _config['warmup_ratio'])

    trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=False, sampler=RandomSampler(trainset))
    valloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, sampler=None)

    # Image-text datasets carry their own tokenizer / image processor used to batch raw samples.
    tokenizer = image_processor = None
    if dataset_type == 1:
        tokenizer = trainset.tokenizer
        image_processor = trainset.image_processor

    scaler = GradScaler(enabled=_config['amp'])
    writer = SummaryWriter(str(save_dir))
    metric = Metrics()
    all_time_best = None

    for epoch in range(num_epoch):
        # Clean Memory
        torch.cuda.empty_cache()
        gc.collect()
        model.train()

        train_loss = 0.0
        num_steps = 0  # batches processed this epoch; stays 0 if the loader yields nothing
        lr = _mean_lr(scheduler)
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{num_epoch}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        for step, sample in pbar:
            if dataset_type == 1:
                images = image_processor(sample['image'], return_tensors="pt", do_rescale=False).to(device)
                texts = tokenizer(
                    sample['text'],
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=max_text_len,
                )
                for key in ('input_ids', 'attention_mask', 'token_type_ids'):
                    texts[key] = texts[key].to(device)
                inputs = [images, texts]
            else:
                inputs = [sample['video'].to(device), sample['audio'].to(device)]

            if is_imdb:
                # Multi-label targets come as a list of tensors; stacked along dim 1 as floats
                # for BCEWithLogitsLoss (presumably (batch, num_classes) — confirm in the dataset).
                lbl = torch.stack(sample['label'], dim=1).to(device).float()
            else:
                lbl = sample['label'].to(device)

            with autocast(enabled=_config['amp']):
                logits, _ = model(inputs)
                loss = loss_fn(logits, lbl)

                # Performance calculation
                if is_imdb:
                    predicted = torch.sigmoid(logits).round().detach().cpu().numpy()
                else:
                    _, predicted = torch.max(logits, 1)
                metric.update(predicted, lbl)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            # Synchronizing is only meaningful (and only valid) on a CUDA device.
            if device.type == 'cuda':
                torch.cuda.synchronize()

            lr = max(_mean_lr(scheduler), 1e-8)  # floor for display: minimum of lr

            train_loss += loss.item()
            num_steps = step + 1

            # Clean Memory
            # NOTE(review): doing this every iteration is expensive; consider once per epoch.
            torch.cuda.empty_cache()
            gc.collect()

            pbar.set_description(f"Epoch: [{epoch+1}/{num_epoch}] Iter: [{num_steps}/{iters_per_epoch}] LR: {lr:.8f} Total Loss: {train_loss / num_steps:.4f}")

            # Stop once exactly ``max_steps`` batches have been processed this epoch.
            if max_steps > 0 and num_steps >= max_steps:
                break
        pbar.close()

        # max(..., 1) guards against an empty loader (fewer samples than batch_size with drop_last).
        train_loss /= max(num_steps, 1)
        writer.add_scalar('train/loss', train_loss, epoch)

        # Calculate evaluation metrics
        all_scores = metric.compute_score(prefix='train_', enable_auroc=_config['exp_name'] == "finetune_hatememes")
        all_scores.update({'train_loss': train_loss, 'epoch': epoch})
        metric.reset()

        # Evaluate
        test_scores = evaluate(
            model, valloader, device, num_classes,
            max_text_len,
            loss_fn=loss_fn,
            loss_fn2=None,
            loss_fn2_weight=0.0,
            is_mmimdb=is_imdb,
            enable_auroc=False,
            dataset_type=dataset_type
        )

        # Log to wandb
        all_scores.update(test_scores)
        pprint.pprint(all_scores)
        logger.info(all_scores)

        if _config['use_wandb']:
            wandb.log(all_scores)

        current_performance = all_scores['test_f1_macro' if is_imdb else 'test_accuracy']

        if best_performance < current_performance:
            all_time_best = all_scores
            # Drop the files saved for the previous best epoch (if any) before writing the new best.
            prev_best_ckp = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}_checkpoint.pth"
            prev_best = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}.pth"
            if os.path.isfile(prev_best):
                os.remove(prev_best)
            if os.path.isfile(prev_best_ckp):
                os.remove(prev_best_ckp)
            best_performance = current_performance
            best_epoch = epoch + 1
            cur_best_ckp = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}_checkpoint.pth"
            cur_best = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}.pth"
            # Unwrap DataParallel so the saved keys have no "module." prefix.
            torch.save(model.module.state_dict(), cur_best)
            torch.save({'epoch': best_epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': train_loss,
                        'scheduler_state_dict': scheduler.state_dict(),
                        'best_performance': best_performance,
                        }, cur_best_ckp)

    # All time best score
    logger.info("Best Score:")
    logger.info(all_time_best)

    # Done!
    writer.close()
    if _config['use_wandb']:
        wandb.finish()
    elapsed = int(time.time() - start)
    hours, rem = divmod(elapsed, 3600)
    minutes, seconds = divmod(rem, 60)
    logger.info(f"Total Training Time: {hours}h {minutes:02d}m {seconds:02d}s")
    