import os
import torch 
import sys
import time
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DistributedSampler, RandomSampler
from utils.utils import fix_seeds, setup_cudnn, get_logger, cal_flops
from utils.metric import Metrics
from config import ex
from models.image_text_model import ViTBertMMT
from dataset.food101_dataset import FOOD101Dataset
from dataset.mmimdb_dataset import MMIMDBDataset
from utils.optimizers import get_optimizer
from utils.schedulers import get_scheduler
from scripts.test_image_text_model import evaluate
import wandb
import functools
from pytorch_metric_learning import losses
import gc
import pprint
import torch.nn.functional as F


def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Counts elements (``numel``) over ``model.parameters()`` and prints the
    trainable count, the total count and the trainable percentage.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        tuple[int, int]: ``(trainable_params, all_params)``. Callers that ignore
        the return value are unaffected.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A module without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"Trainable Params: {trainable_params} || All Params: {all_param} || Trainable (%): {trainable_pct:.2f}"
    )
    return trainable_params, all_param


def _build_dataset(ds, _config, split, missing_info, enable_mt, text_model):
    """Instantiate dataset class ``ds`` for ``split`` with the experiment's shared settings.

    The train and test sets in ``main`` are built with identical arguments apart
    from ``split``; centralising the call keeps the two in sync.
    """
    # NOTE(review): the test split is also built with ``train_transform_keys`` (as in the
    # original code). Confirm whether a separate validation transform key should be used.
    return ds(
        _config['data_dir'],
        _config['train_transform_keys'],
        split=split,
        image_size=_config['image_size'],
        max_text_len=_config['max_text_len'],
        draw_false_image=_config['draw_false_image'],
        draw_false_text=_config['draw_false_text'],
        image_only=False,
        missing_info=missing_info,
        enable_mt=enable_mt,
        text_model=text_model,
    )


@ex.automain
def main(_config):
    """Fine-tune ``ViTBertMMT`` on MM-IMDb or Food-101 as configured by the sacred experiment.

    Builds the train/test datasets with the configured missing-modality settings and
    trains for ``max_epoch`` epochs. When ``enable_mt`` is set, it adds a weighted MSE
    alignment loss between real and estimated tokens. It evaluates after every epoch
    and logs to ``train.log``, TensorBoard and (optionally) wandb. Only the best
    checkpoint is kept on disk, selected by ``test_f1_macro`` for MM-IMDb and
    ``test_accuracy`` for Food-101.

    Args:
        _config: Sacred configuration dictionary injected by ``ex.automain``.
    """
    print("Experiment Configurations:")
    print(_config)
    fix_seeds(_config["seed"])
    setup_cudnn()
    wandb_exp_name = _config["wandb_exp_name"]
    if _config['use_wandb']:
        wandb.init(project=_config['wandb_project_name'], entity="ENTITY", name=wandb_exp_name)

    save_dir = Path(_config['save_dir'], wandb_exp_name)
    os.makedirs(save_dir, exist_ok=True)
    logger = get_logger(save_dir / 'train.log')

    start = time.time()
    best_performance = 0.0
    best_epoch = 0
    num_workers = _config['num_workers']
    device = torch.device(_config['device'])
    num_classes = _config['class_num']
    max_text_len = _config['max_text_len']
    batch_size = _config['batch_size']
    num_epoch = _config['max_epoch']
    max_steps = _config['max_steps']
    enable_lora = _config['enable_lora']
    enable_mt = _config['enable_mt']
    text_model = _config['text_model']
    align_weight = _config['mt_alignment_loss_weight']
    is_imdb = _config['exp_name'] == "finetune_mmimdb"

    # construct missing modality info
    missing_info = {
        'ratio': _config["missing_ratio"],
        'type': _config["missing_type"],
        'both_ratio': _config["both_ratio"],
        'missing_table_root': _config["missing_table_root"],
        'simulate_missing': _config["simulate_missing"],
        'only_paired': _config["only_paired"],
    }

    # Dataset
    if _config['exp_name'] == "finetune_mmimdb":
        ds = MMIMDBDataset
    elif _config['exp_name'] == "finetune_food101":
        ds = FOOD101Dataset
    else:
        sys.exit("No valid experiment selected. Aborting!")

    trainset = _build_dataset(ds, _config, "train", missing_info, enable_mt, text_model)
    valset = _build_dataset(ds, _config, "test", missing_info, enable_mt, text_model)

    model = ViTBertMMT(
        num_classes,
        max_text_len,
        r=_config['r'],
        lora_alpha=_config['lora_alpha'],
        lora_dropout=_config['lora_dropout'],
        vit_target_modules=_config['vit_target_modules'],
        bert_target_modules=_config['bert_target_modules'],
        enable_lora=enable_lora,
        enable_mt=enable_mt,
        text_model=text_model,
    )

    print_trainable_parameters(model)
    print("Total Train Samples:", len(trainset))
    print("Total Test Samples:", len(valset))

    logger.info('================== model structure =====================')
    logger.info(model)
    logger.info('================== training config =====================')
    logger.info(_config)
    logger.info('================== parameter count =====================')
    logger.info(sum(p.numel() for p in model.parameters() if p.requires_grad))
    logger.info(f"Total Train Samples: {len(trainset)}")
    logger.info(f"Total Test Samples: {len(valset)}")

    model = torch.nn.DataParallel(model, device_ids=_config['gpu_ids'])
    model = model.to(device)

    # Matches len(trainloader) because the train loader uses drop_last=True.
    iters_per_epoch = len(trainset) // batch_size
    loss_fn = nn.BCEWithLogitsLoss() if is_imdb else nn.CrossEntropyLoss()
    alignment_loss = nn.MSELoss()

    optimizer = get_optimizer(model, _config['optim_type'], _config['learning_rate'], _config['weight_decay'])
    scheduler = get_scheduler(_config['scheduler'], optimizer, int((num_epoch + 1) * iters_per_epoch), _config['power'], iters_per_epoch * _config['warmup'], _config['warmup_ratio'])

    trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=False, sampler=RandomSampler(trainset))
    valloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, sampler=None)
    tokenizer = trainset.tokenizer
    image_processor = trainset.image_processor

    scaler = GradScaler(enabled=_config['amp'])
    writer = SummaryWriter(str(save_dir))
    metric = Metrics()
    all_time_best = None

    for epoch in range(num_epoch):
        # Clean Memory
        torch.cuda.empty_cache()
        gc.collect()
        model.train()

        train_loss = 0.0
        total_ce_loss = 0.0
        total_alignment_loss = 0.0
        num_steps = 0  # iterations actually executed this epoch
        lr = scheduler.get_lr()
        lr = sum(lr) / len(lr)
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{num_epoch}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        for step, sample in pbar:
            num_steps = step + 1
            images = image_processor(sample['image'], return_tensors="pt", do_rescale=False).to(device)
            texts = tokenizer(
                sample['text'],
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=max_text_len,
            )
            missing_type = sample['missing_type'].to(device)

            texts['input_ids'] = texts['input_ids'].to(device)
            texts['attention_mask'] = texts['attention_mask'].to(device)
            if text_model == 'bert-base-uncased':
                texts['token_type_ids'] = texts['token_type_ids'].to(device)

            optimizer.zero_grad(set_to_none=True)
            image_text = [images, texts]

            if is_imdb:
                # Multi-label targets: stack the per-class label tensors into (batch, classes).
                lbl = torch.stack(sample['label'], dim=1).to(device).float()
            else:
                lbl = sample['label'].to(device)

            with autocast(enabled=_config['amp']):
                logits_fused, real_tokens, estimated_tokens, _ = model(image_text, missing_type)
                loss = loss_fn(logits_fused, lbl)

                if enable_mt and len(real_tokens) > 0:
                    align_loss = alignment_loss(torch.stack(real_tokens), torch.stack(estimated_tokens))
                else:
                    # No token pairs to align: use a zero scalar created on the same
                    # device/dtype as ``loss`` (avoids allocating CPU dummy tensors).
                    align_loss = loss.new_zeros(())

                if is_imdb:
                    predicted = torch.sigmoid(logits_fused).round().detach().cpu().numpy()
                else:
                    _, predicted = torch.max(logits_fused, 1)
                metric.update(predicted, lbl)

            total_loss = (loss + align_weight * align_loss) if enable_mt else loss
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Clear gradients
            optimizer.zero_grad()
            # synchronize() raises on a CPU-only setup, so only call it for CUDA devices.
            if device.type == 'cuda':
                torch.cuda.synchronize()

            lr = scheduler.get_lr()
            lr = sum(lr) / len(lr)
            lr = max(lr, 1e-8)  # display floor only; does not affect the optimizer

            loss_value = loss.item()
            train_loss += loss_value
            total_ce_loss += loss_value
            if enable_mt:
                weighted_align = align_weight * align_loss.item()
                train_loss += weighted_align
                total_alignment_loss += weighted_align

            # Clean Memory
            torch.cuda.empty_cache()
            gc.collect()

            desc = (
                f"Epoch: [{epoch+1}/{num_epoch}] Iter: [{num_steps}/{iters_per_epoch}] LR: {lr:.8f} "
                f"Total Loss: {train_loss / num_steps:.4f} CE Loss: {total_ce_loss / num_steps:.4f}"
            )
            if enable_mt:
                desc += f" Align Loss: {total_alignment_loss / num_steps:.4f}"
            pbar.set_description(desc)

            # Stop after exactly ``max_steps`` iterations (the old ``>`` test ran one extra).
            if max_steps > 0 and num_steps >= max_steps:
                break
        pbar.close()

        # Average over the iterations that actually ran; guard against an empty loader.
        denom = max(num_steps, 1)
        train_loss /= denom
        total_ce_loss /= denom
        total_alignment_loss /= denom
        writer.add_scalar('train/loss', train_loss, epoch)

        # Calculate evaluation metrics
        all_scores = metric.compute_score(prefix='train_')
        all_scores.update({'train_loss': train_loss, 'epoch': epoch})
        if enable_mt:
            all_scores.update({'train_ce_loss': total_ce_loss, 'train_alignment_loss': total_alignment_loss})
        else:
            all_scores.update({'train_ce_loss': total_ce_loss})
        metric.reset()

        # Evaluate
        test_scores = evaluate(
            model, valloader, device, num_classes,
            max_text_len,
            loss_fn=loss_fn,
            loss_fn2=alignment_loss if enable_mt else None,
            loss_fn2_weight=align_weight,
            is_mmimdb=is_imdb,
            enable_auroc=False,
            enable_mt=enable_mt,
            text_model=text_model,
        )

        # Log to wandb
        all_scores.update(test_scores)
        pprint.pprint(all_scores)
        logger.info(all_scores)

        if _config['use_wandb']:
            wandb.log(all_scores)

        if _config['exp_name'] == "finetune_mmimdb":
            current_performance = all_scores['test_f1_macro']
        elif _config['exp_name'] == "finetune_food101":
            current_performance = all_scores['test_accuracy']
        else:
            sys.exit("No valid experiment selected. Aborting!")

        if best_performance < current_performance:
            all_time_best = all_scores
            # Delete the previous best weights/checkpoint so only the newest best remains.
            prev_best_ckp = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}_checkpoint.pth"
            prev_best = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}.pth"
            if os.path.isfile(prev_best):
                os.remove(prev_best)
            if os.path.isfile(prev_best_ckp):
                os.remove(prev_best_ckp)
            best_performance = current_performance
            best_epoch = epoch + 1
            cur_best_ckp = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}_checkpoint.pth"
            cur_best = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}.pth"
            # ``model`` is wrapped in DataParallel, so save the inner module's weights.
            torch.save(model.module.state_dict(), cur_best)
            torch.save({'epoch': best_epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': train_loss,
                        'scheduler_state_dict': scheduler.state_dict(),
                        'best_performance': best_performance,
                        }, cur_best_ckp)
    # All time best score
    logger.info("Best Score:")
    logger.info(all_time_best)

    # Done!
    writer.close()
    elapsed = time.time() - start
    end = time.gmtime(elapsed)
    # %H in strftime wraps at 24h, so compute whole hours from the raw elapsed seconds.
    logger.info(f"Total Training Time: {int(elapsed // 3600)}:{time.strftime('%M:%S', end)}")
    