import os
import torch 
import sys
import time
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DistributedSampler, RandomSampler
from utils.utils import fix_seeds, setup_cudnn, get_logger, cal_flops
from utils.metric import Metrics
from config import ex
from models.audio_video_model import AudioVideoModelWithMT
from dataset.ks import KineticsSound
from dataset.ave import AVE
from dataset.cremad import CREMAD
from utils.optimizers import get_optimizer
from utils.schedulers import get_scheduler
from scripts.test_audio_video_model import evaluate
import wandb
import gc
import pprint
import torch.nn.functional as F


def print_trainable_parameters(model):
    """
    Print the trainable and total parameter counts of ``model``.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and a ``requires_grad`` flag.

    The trainable share is printed as a percentage with two decimals. A model
    without any parameters reports ``0.00`` instead of raising
    ``ZeroDivisionError``.
    """
    all_param = 0
    trainable_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # Guard the ratio: an empty model has nothing to divide by.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"Trainable Params: {trainable_params} || All Params: {all_param} || Trainable (%): {trainable_pct:.2f}"
    )


@ex.automain
def main(_config):
    """Train an audio-video classifier and keep the best-scoring checkpoint.

    Builds the train/test datasets selected by ``_config['exp_name']``, wraps
    ``AudioVideoModelWithMT`` in ``DataParallel``, and runs
    ``_config['max_epoch']`` epochs of training (cross-entropy, plus a weighted
    MSE alignment term when ``_config['enable_mt']`` is set). After every epoch
    the model is evaluated on the ``"test"`` split; whenever
    ``all_scores['test_accuracy']`` improves, the model weights and a
    checkpoint holding the optimizer/scheduler state are written under
    ``save_dir / wandb_exp_name`` and the previous best files are removed.

    Args:
        _config: configuration mapping supplied by ``ex`` (presumably a Sacred
            experiment -- confirm in ``config``).

    Raises:
        ValueError: if ``_config['exp_name']`` names no known dataset, or if
            the training set is smaller than one batch (the loader uses
            ``drop_last=True`` and would yield nothing).
    """
    print("Experiment Configurations:")
    print(_config)

    fix_seeds(_config["seed"])
    setup_cudnn()
    wandb_exp_name = _config["wandb_exp_name"]
    if _config['use_wandb']:
        wandb.init(project=_config['wandb_project_name'], entity="ENTITY", name=wandb_exp_name)

    save_dir = Path(_config['save_dir'], wandb_exp_name)

    os.makedirs(save_dir, exist_ok=True)
    logger = get_logger(save_dir / 'train.log')

    start = time.time()
    best_performance = 0.0
    best_epoch = 0
    num_workers = _config['num_workers']
    device = torch.device(_config['device'])
    num_classes = _config['class_num']
    batch_size = _config['batch_size']
    num_epoch = _config['max_epoch']
    enable_lora = _config['enable_lora']
    enable_mt = _config['enable_mt']

    # construct missing modality info
    missing_info = {
        'ratio' : _config["missing_ratio"],
        'type' : _config["missing_type"],
        'both_ratio' : _config["both_ratio"],
        'missing_table_root': _config["missing_table_root"],
        'simulate_missing' : _config["simulate_missing"],
        'only_paired' : _config["only_paired"]
    }

    # Dataset: fail with a clear message instead of an unbound-name error
    # when the experiment name is not one of the supported fine-tuning runs.
    dataset_classes = {
        "finetune_ks": KineticsSound,
        "finetune_ave": AVE,
        "finetune_cremad": CREMAD,
    }
    exp_name = _config['exp_name']
    if exp_name not in dataset_classes:
        raise ValueError(
            f"Unsupported exp_name {exp_name!r}; expected one of {sorted(dataset_classes)}"
        )
    ds = dataset_classes[exp_name]

    trainset = ds(
        split="train",
        dataset_root_dir=_config['data_dir'],
        missing_info=missing_info,
    )
    valset = ds(
        split="test",
        dataset_root_dir=_config['data_dir'],
        missing_info=missing_info,
    )

    model = AudioVideoModelWithMT(
        num_classes,
        r=_config['r'],
        lora_alpha=_config['lora_alpha'],
        lora_dropout=_config['lora_dropout'],
        vit_target_modules = _config['vit_target_modules'],
        ast_target_modules = _config['ast_target_modules'],
        enable_lora=enable_lora,
        enable_mt=enable_mt,
    )

    print_trainable_parameters(model)
    print("Total Train Samples:", len(trainset))
    print("Total Test Samples:", len(valset))

    logger.info('================== model structure =====================')
    logger.info(model)
    logger.info('================== training config =====================')
    logger.info(_config)
    logger.info('================== parameter count =====================')
    logger.info(sum(p.numel() for p in model.parameters() if p.requires_grad))
    logger.info(f"Total Train Samples: {len(trainset)}")
    logger.info(f"Total Test Samples: {len(valset)}")

    model = torch.nn.DataParallel(model, device_ids=_config['gpu_ids'])
    model = model.to(device)

    # The train loader drops the last partial batch, so this is its length.
    iters_per_epoch = len(trainset) // batch_size
    if iters_per_epoch == 0:
        raise ValueError(
            f"Training set has {len(trainset)} samples, fewer than batch_size={batch_size}; "
            "with drop_last=True no training batch would be produced."
        )
    loss_fn = nn.CrossEntropyLoss()
    alignment_loss = nn.MSELoss()
    mt_weight = _config['mt_alignment_loss_weight']

    optimizer = get_optimizer(model, _config['optim_type'], _config['learning_rate'], _config['weight_decay'])
    scheduler = get_scheduler(_config['scheduler'], optimizer, int((num_epoch+1)*iters_per_epoch), _config['power'], iters_per_epoch * _config['warmup'], _config['warmup_ratio'])

    sampler = RandomSampler(trainset)

    trainloader = DataLoader(trainset, batch_size=_config['batch_size'], num_workers=num_workers, drop_last=True, pin_memory=False, sampler=sampler)
    # No sampler for evaluation: the loader iterates the test split in order.
    valloader = DataLoader(valset, batch_size=_config['batch_size'], num_workers=num_workers, pin_memory=False)

    scaler = GradScaler(enabled=_config['amp'])
    writer = SummaryWriter(str(save_dir))
    metric = Metrics()
    all_time_best = None

    for epoch in range(num_epoch):
        # Clean Memory
        torch.cuda.empty_cache()
        gc.collect()
        model.train()

        train_loss = 0.0
        total_ce_loss = 0.0
        total_alignment_loss = 0.0
        # NOTE(review): if the scheduler derives from torch's LR scheduler
        # base class, get_last_lr() is the documented accessor -- confirm
        # against utils.schedulers before switching.
        lr = scheduler.get_lr()
        lr = sum(lr) / len(lr)
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{num_epoch}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        num_steps = 0
        for step, sample in pbar:
            num_steps = step + 1
            optimizer.zero_grad(set_to_none=True)
            video_audio = [sample['video'].to(device), sample['audio'].to(device)]
            lbl = sample['label'].to(device)
            missing_type = sample['missing_type'].to(device)

            with autocast(enabled=_config['amp']):
                logits_fused, real_tokens, estimated_tokens, fused_features = model(video_audio, missing_type)
                loss = loss_fn(logits_fused, lbl)

                if enable_mt and len(real_tokens) > 0:
                    align_loss = alignment_loss(torch.stack(real_tokens), torch.stack(estimated_tokens))
                else:
                    # Nothing to align: use a zero scalar that lives on the
                    # same device/dtype as the classification loss.
                    align_loss = loss.new_zeros(())

                _, predicted = torch.max(logits_fused, 1)
                metric.update(predicted, lbl)

            total_loss = loss + mt_weight * align_loss if enable_mt else loss
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Clear gradients
            optimizer.zero_grad()
            torch.cuda.synchronize()

            lr = scheduler.get_lr()
            lr = sum(lr) / len(lr)
            if lr <= 1e-8:
                lr = 1e-8 # minimum of lr

            ce_value = loss.item()
            train_loss += ce_value
            total_ce_loss += ce_value
            if enable_mt:
                align_value = mt_weight * align_loss.item()
                train_loss += align_value
                total_alignment_loss += align_value

            # Clean Memory
            torch.cuda.empty_cache()
            gc.collect()

            desc = f"Epoch: [{epoch+1}/{num_epoch}] Iter: [{step+1}/{iters_per_epoch}] LR: {lr:.8f} Total Loss: {train_loss / (step+1):.4f} CE Loss: {total_ce_loss / (step+1):.4f}"
            if enable_mt:
                desc += f" Align Loss: {total_alignment_loss / (step+1):.4f}"
            pbar.set_description(desc)

            # NOTE(review): this comparison lets max_steps + 1 iterations run
            # before breaking -- confirm whether that is intended.
            if (_config['max_steps'] > 0 and (step+1) > _config['max_steps']):
                break

        # Close this epoch's bar here so a bar left open by `break` is not
        # leaked and no bar is referenced when zero epochs are requested.
        pbar.close()

        # Turn the running sums into per-step averages.
        train_loss /= num_steps
        total_ce_loss /= num_steps
        if enable_mt:
            total_alignment_loss /= num_steps
        writer.add_scalar('train/loss', train_loss, epoch)

        # Calculate evaluation metrics
        all_scores = metric.compute_score(prefix='train_')
        all_scores.update({'train_loss': train_loss, 'epoch': epoch})
        if enable_mt:
            all_scores.update({'train_ce_loss': total_ce_loss, 'train_alignment_loss': total_alignment_loss})
        else:
            all_scores.update({'train_ce_loss': total_ce_loss})
        metric.reset()

        # Evaluate
        test_scores = evaluate(
            model, valloader, device, num_classes,
            loss_fn=loss_fn,
            loss_fn2=alignment_loss if enable_mt else None,
            loss_fn2_weight=mt_weight,
            is_mmimdb=False,
            enable_mt=enable_mt,
        )

        # Log to wandb
        all_scores.update(test_scores)
        pprint.pprint(all_scores)
        logger.info(all_scores)

        if _config['use_wandb']:
            wandb.log(all_scores)

        current_performance = all_scores['test_accuracy']

        if best_performance < current_performance:
            all_time_best = all_scores
            # Remove the files written for the previous best epoch, if any.
            prev_best_ckp = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}_checkpoint.pth"
            prev_best = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}.pth"
            if os.path.isfile(prev_best): os.remove(prev_best)
            if os.path.isfile(prev_best_ckp): os.remove(prev_best_ckp)
            best_performance = current_performance
            best_epoch = epoch+1
            cur_best_ckp = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}_checkpoint.pth"
            cur_best = save_dir / f"{wandb_exp_name}_epoch{best_epoch}_{best_performance}.pth"
            # Save the state dict of the module wrapped by DataParallel.
            torch.save(model.module.state_dict(), cur_best)
            # Checkpoint that also carries optimizer and scheduler state.
            torch.save({'epoch': best_epoch,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': train_loss,
                        'scheduler_state_dict': scheduler.state_dict(),
                        'best_performance': best_performance,
                        }, cur_best_ckp)
    # All time best score
    logger.info("Best Score:")
    logger.info(all_time_best)

    # Done!
    writer.close()
    # Report wall-clock duration; hours are not wrapped at 24.
    elapsed = int(time.time() - start)
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    logger.info(f"Total training time: {hours:02d}:{minutes:02d}:{seconds:02d}")
    