#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from __future__ import print_function
import argparse
import os
import sys
import time
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import scipy.io as io
import numpy as np
from torch.autograd import Variable
from dataloader import *
from model import AGOPMIPL
from utils import *
from rfm import compute_agop, extract_bag_features
import warnings

# Silence every Python warning for the lifetime of the process so the
# training log only contains the script's own output.
warnings.filterwarnings('ignore')
parser = argparse.ArgumentParser(description='AGOPMIPL - AGOP based Multi-Instance Partial Label Learning')

# Basic training arguments
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--reg', type=float, default=1e-5)
parser.add_argument('--L', type=int, default=128)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--data_path', type=str, default='./data')
parser.add_argument('--index', type=str, default='index')
parser.add_argument('--ds', type=str, default='CRC-MIPL-SBN')
parser.add_argument('--ds_suffix', type=str, default='')
parser.add_argument('--nr_fea', type=int, default=256)
parser.add_argument('--nr_class', type=int, default=7)
parser.add_argument('--normalize', action='store_true')

parser.add_argument('--mu', type=float, default=0.1, help='sparsity loss weight')
parser.add_argument('--gamma', type=float, default=0.5, help='inhibition loss weight')
parser.add_argument('--inst_weight', type=float, default=0.5, help='instance aux loss weight')
parser.add_argument('--proto_agg', type=str, default='mean', choices=['mean', 'linear'], help='prototype aggregation: mean (V8) or linear (V1)')
parser.add_argument('--attn_lambda', type=float, default=0.3, help='weight for raw attention path in dual-path attention')

# AGOP arguments
parser.add_argument('--agop_rounds', type=int, default=3)
parser.add_argument('--epochs_per_round', type=int, default=100)
parser.add_argument('--agop_momentum', type=float, default=0.0)
parser.add_argument('--no_agop', action='store_true', default=False)
parser.add_argument('--no_disamb', action='store_true', default=False)

parser.add_argument('--nr_samples', type=int, default=10, help='Bayesian classifier samples')
parser.add_argument('--kld_weight', type=float, default=0.01, help='KL divergence weight')

parser.add_argument('--lr_warmup_epochs', type=int, default=10, help='LR warmup epochs per RFM round')
# argparse applies %-formatting to help strings, so a literal percent sign
# must be written as '%%'; a bare '%' makes `--help` raise a TypeError.
parser.add_argument('--alpha_warmup_ratio', type=float, default=0.2, help='Alpha warmup ratio (first 20%% epochs)')

# NOTE: arguments are parsed at import time; evaluate() and train() below read
# the resulting module-level `args` and `device`.
args = parser.parse_args()

seed_everything(args.seed)
args.cuda = not args.no_cuda and torch.cuda.is_available()
# Derive the device from args.cuda (not from CUDA availability alone) so that
# --no-cuda is honoured everywhere: the model and training batches are moved
# with `if args.cuda: ....cuda()`, while prototypes and evaluation batches are
# moved with `.to(device)`. Both must agree or tensors end up on different
# devices.
device = torch.device("cuda" if args.cuda else "cpu")

if args.cuda:
    print('\nGPU is available!')

# File names of the cross-validation index files, one per fold.
all_folds = ['index1.mat', 'index2.mat', 'index3.mat', 'index4.mat', 'index5.mat',
             'index6.mat', 'index7.mat', 'index8.mat', 'index9.mat', 'index10.mat']


def adjust_alpha_progressive(epochs, warmup_ratio=0.2):
    """Build the per-epoch alpha schedule: a flat warmup, then a linear ramp-down.

    Args:
        epochs: Number of schedule entries to generate.
        warmup_ratio: Fraction of ``epochs`` (truncated to a whole number of
            epochs) during which alpha is held at 1.0.

    Returns:
        A list of ``epochs`` floats. Warmup entries are 1.0. After the warmup,
        the k-th entry (k counted from 0) is ``1 - k / decay_epochs`` clipped
        at 0.0, where ``decay_epochs = epochs - warmup_epochs``. For a ratio
        in [0, 1) the last entry is therefore ``1 / decay_epochs``, not 0.0.
    """
    n_warm = int(warmup_ratio * epochs)
    n_decay = epochs - n_warm

    def alpha_at(ep):
        # Warmup: keep full weight.
        if ep < n_warm:
            return 1.0
        steps_into_decay = ep - n_warm
        if n_decay > 0:
            value = 1.0 - (steps_into_decay / n_decay)
        else:
            value = 0.0
        return max(0.0, value)

    return [alpha_at(ep) for ep in range(epochs)]


def get_warmup_scheduler(optimizer, warmup_epochs, total_epochs, base_lr):
    """Return a ``LambdaLR`` scheduler with linear warmup followed by cosine decay.

    The returned scheduler multiplies the optimizer's initial learning rate by:

    * ``(epoch + 1) / max(1, warmup_epochs)`` while ``epoch < warmup_epochs``
      (so the first epoch already uses ``1 / warmup_epochs`` of the rate);
    * ``0.5 * (1 + cos(pi * progress))`` afterwards, where ``progress`` is the
      fraction of the post-warmup epochs that has elapsed.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_epochs: Number of epochs in the linear warmup phase.
        total_epochs: Total number of epochs, warmup included.
        base_lr: Accepted for interface compatibility; not read here, the
            scale comes from the learning rate already set on ``optimizer``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    ramp_len = float(max(1, warmup_epochs))
    anneal_len = max(1, total_epochs - warmup_epochs)

    def lr_lambda(current_epoch):
        if current_epoch >= warmup_epochs:
            # Cosine annealing over the remaining epochs.
            progress = (current_epoch - warmup_epochs) / anneal_len
            return 0.5 * (1.0 + math.cos(math.pi * progress))
        # Linear ramp-up.
        return float(current_epoch + 1) / ramp_len

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


def evaluate(train_loader, loader, model, verbose=True):
    """Compute the accuracy of ``model`` on the bags yielded by ``loader``.

    Args:
        train_loader: Object whose ``prototypes_matrix`` is passed to the model
            for every bag.
        loader: Iterable yielding 4-tuples; the first item is the bag data and
            the third is the true bag label (the other two are ignored).
        model: Model exposing ``eval()`` and ``evaluate_objective(data, proto)``.
        verbose: When true, print the accuracy.

    Returns:
        Fraction of bags whose arg-max prediction equals the arg-max of the
        true label.
    """
    model.eval()
    proto_matrix = train_loader.prototypes_matrix.clone().detach().to(device).to(torch.float32)

    # Seed the stack with an empty (0, nr_class) float64 array so the stacked
    # result has the same dtype and column check as row-by-row stacking.
    prob_chunks = [np.empty((0, args.nr_class))]
    label_rows = []

    with torch.no_grad():
        for data, _, true_bag_lab, _ in loader:
            bag = data.to(device).to(torch.float32)
            scores = model.evaluate_objective(bag, proto_matrix)
            prob_chunks.append(scores.detach().cpu().numpy())
            label_rows.append(true_bag_lab.detach().cpu().numpy().astype(np.float64))

    all_pred_bag_prob = np.vstack(prob_chunks)
    all_true_bag_lab = np.array(label_rows)

    predicted = np.argmax(all_pred_bag_prob, axis=1)
    expected = np.argmax(all_true_bag_lab, axis=1)
    accuracy = np.mean(predicted == expected)

    if verbose:
        print(f'\tAccuracy: {accuracy:.3f}')

    return accuracy


def train(train_loader, epoch, alpha_list, current_lr):
    """Run one training epoch using the module-level ``model`` and ``optimizer``.

    For every bag the loss is computed, and (unless ``--no_disamb`` is set)
    the bag's stored partial label is replaced by the soft label returned by
    ``model.regenerate_soft_labels`` before the backward pass and the
    optimizer step.

    Args:
        train_loader: Iterable of ``(data, partial_bag_lab, true_bag_lab, index)``
            that also exposes ``prototypes_matrix``,
            ``train_partial_bag_lab_list`` and ``len()``.
        epoch: 1-based epoch number within the current round.
        alpha_list: Per-epoch alpha values; the last value is reused when
            ``epoch`` exceeds the list length.
        current_lr: Learning rate, used only for the progress message.

    Returns:
        ``(model, mean_loss)`` where ``mean_loss`` is the summed loss divided
        by ``len(train_loader)``.
    """
    model.train()
    proto_matrix = train_loader.prototypes_matrix.clone().detach().to(device).to(torch.float32)

    # epoch is 1-based, alpha_list is 0-based.
    if epoch <= len(alpha_list):
        alpha = alpha_list[epoch - 1]
    else:
        alpha = alpha_list[-1]

    running_loss = 0.
    for bag, cand_lab, _true_lab, bag_index in train_loader:
        if args.cuda:
            bag = bag.cuda()
            cand_lab = cand_lab.cuda()

        bag = Variable(bag).to(torch.float32)
        cand_lab = Variable(cand_lab).to(torch.float32)

        optimizer.zero_grad()

        loss, prediction, _attention, _h_prob = model.calculate_objective(
            bag, cand_lab, proto_matrix, args
        )
        running_loss += loss.item()

        # Progressive disambiguation: overwrite this bag's stored label with
        # the regenerated soft label (computed without tracking gradients).
        if not args.no_disamb:
            with torch.no_grad():
                if prediction.dim() == 2:
                    pred = prediction
                else:
                    pred = prediction.unsqueeze(0)
                flat_lab = cand_lab.reshape(1, -1)
                refreshed = model.regenerate_soft_labels(flat_lab, pred.detach(), alpha)
                refreshed = refreshed.squeeze(0).cpu().numpy()
            train_loader.train_partial_bag_lab_list[bag_index] = torch.tensor(refreshed)

        loss.backward()
        optimizer.step()

    mean_loss = running_loss / len(train_loader)

    # Log the first epoch and every tenth one.
    if epoch == 1 or epoch % 10 == 0:
        print(f'Epoch: {epoch}, Loss: {mean_loss:.4f}, LR: {current_lr:.6f}, alpha: {alpha:.4f}')

    return model, mean_loss


import math

# Script entry point: for every fold listed in `all_folds`, build the train and
# test sets, train a fresh AGOPMIPL model for `args.agop_rounds` rounds, and
# print per-fold and summary accuracies.
#
# NOTE: `model` and `optimizer` are assigned here at module scope and are read
# by train() as globals, so these names must not be renamed or moved into a
# function without also changing train().
if __name__ == "__main__":
    time_s = time.time()

    num_trial = 1
    num_fold = len(all_folds)
    data_path = os.path.join(args.data_path, args.ds)
    index_path = os.path.join(data_path, args.index)

    # Data file is <data_path>/<ds>/<ds>[_<ds_suffix>].mat
    if args.ds_suffix:
        mat_name = args.ds + '_' + args.ds_suffix + '.mat'
    else:
        mat_name = args.ds + '.mat'
    mat_path = os.path.join(data_path, mat_name)

    print(f'\n{"="*60}')
    print(f' AGOPMIPL: AGOP based Multi-Instance Partial Label Learning')
    print(f'{"="*60}')
    print(f'Dataset: {args.ds}')
    print(f'Mat file: {mat_path}')
    print(f'AGOP rounds: {args.agop_rounds}, Epochs/round: {args.epochs_per_round}')
    print(f'\n=== Settings ===')
    print(f'Bayesian samples: {args.nr_samples}')
    print(f'KLD weight: {args.kld_weight}')
    print(f'LR warmup: {args.lr_warmup_epochs} epochs')
    print(f'Alpha warmup: {args.alpha_warmup_ratio*100:.0f}% of epochs')
    print(f'Sparsity (mu): {args.mu}, Inhibition (gamma): {args.gamma}')

    if not os.path.exists(mat_path):
        print(f'\n[ERROR] Mat file not found: {mat_path}')
        sys.exit(1)

    # Auto-detect dimensions
    data_mat = io.loadmat(mat_path)
    data = data_mat['data']

    # Take the column count of the first entry with at least one row as the
    # feature dimension, overriding --nr_fea when they differ.
    # (presumably each data[i, 0] is an (instances, features) array — confirm)
    for i in range(data.shape[0]):
        if data[i, 0].shape[0] > 0:
            actual_nr_fea = data[i, 0].shape[1]
            if actual_nr_fea != args.nr_fea:
                print(f'[AUTO] Feature dim: {actual_nr_fea}')
                args.nr_fea = actual_nr_fea
            break

    # The first of these keys that is present with more than one column sets
    # the class count, overriding --nr_class when they differ.
    for key in ['bag_lab', 'partial_bag_lab']:
        if key in data_mat:
            lab_data = data_mat[key]
            if hasattr(lab_data, 'shape') and len(lab_data.shape) >= 2 and lab_data.shape[1] > 1:
                if lab_data.shape[1] != args.nr_class:
                    print(f'[AUTO] Classes: {lab_data.shape[1]}')
                    args.nr_class = lab_data.shape[1]
                break

    all_ins_fea, bag_idx_of_ins, dummy_ins_lab, bag_lab, partial_bag_lab, partial_bag_lab_processed = load_data_mat(
        mat_path, args.nr_fea, args.nr_class, normalize=args.normalize)

    print(f'\nData loaded: {len(bag_lab)} bags')

    # One dict per completed fold: {'fold', 'final', 'best'}.
    all_fold_results = []

    for trial_i in range(num_trial):
        for fold_i in range(num_fold):
            print(f'\n{"="*50}')
            print(f'Trial {trial_i + 1}, Fold {fold_i + 1}')
            print(f'{"="*50}')

            idx_file = index_path + '/' + all_folds[fold_i]

            # A missing index file skips the fold instead of aborting the run.
            if not os.path.exists(idx_file):
                print(f'[WARN] Index not found: {idx_file}')
                continue

            idx_tr, idx_te = load_idx_mat(idx_file)
            print(f'Train: {len(idx_tr)}, Test: {len(idx_te)}')

            # Both sets are built from the same arrays and index split; only
            # the `train` flag differs.
            train_set = MIPMLDataloader(all_ins_fea, bag_idx_of_ins, dummy_ins_lab, bag_lab, 
                                        partial_bag_lab_processed, idx_tr, idx_te, args.nr_fea, 
                                        train=True, normalize=args.normalize)
            test_set = MIPMLDataloader(all_ins_fea, bag_idx_of_ins, dummy_ins_lab, bag_lab, 
                                       partial_bag_lab_processed, idx_tr, idx_te, args.nr_fea, 
                                       train=False, normalize=args.normalize)

            # Prepare the training set's prototypes; train() and evaluate()
            # later read `train_set.prototypes_matrix`.
            train_set.clear_classes(args)
            train_set.classify_by_partial_label(args)
            train_set.generate_prototypes(args, device)

            # Initialize AGOPMIPL model (a new instance for every fold)
            model = AGOPMIPL(args)
            if args.cuda:
                model.cuda()

            print(f'[Model] AGOPMIPL initialized')

            # NOTE: best_acc is the highest accuracy seen on the *test* set at
            # the periodic evaluation points below.
            best_acc = 0.0
            # Counts epochs across rounds; not read anywhere else in this block.
            total_epochs = 0

            for agop_round in range(args.agop_rounds):
                print(f'\n--- AGOP Round {agop_round + 1}/{args.agop_rounds} ---')

                # Reset optimizer with warmup scheduler for each RFM round
                optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, 
                                     nesterov=True, weight_decay=args.reg)
                scheduler = get_warmup_scheduler(optimizer, args.lr_warmup_epochs, 
                                                 args.epochs_per_round, args.lr)

                # The alpha schedule is rebuilt each round, so it restarts
                # from its warmup value at the beginning of every round.
                alpha_list = adjust_alpha_progressive(args.epochs_per_round, args.alpha_warmup_ratio)

                for epoch in range(1, args.epochs_per_round + 1):
                    total_epochs += 1
                    # Read the rate before training so the logged value is the
                    # one actually used for this epoch.
                    current_lr = optimizer.param_groups[0]['lr']
                    model, loss = train(train_set, epoch, alpha_list, current_lr)
                    scheduler.step()

                    # Evaluate every 20 epochs and at the end of the round.
                    if epoch % 20 == 0 or epoch == args.epochs_per_round:
                        print(f'\n[Eval] Round {agop_round + 1}, Epoch {epoch}')
                        print('  Train:', end='')
                        train_acc = evaluate(train_set, train_set, model)
                        print('  Test:', end='')
                        test_acc = evaluate(train_set, test_set, model)

                        if test_acc > best_acc:
                            best_acc = test_acc
                            print(f'  * Best: {best_acc:.3f}')

                # AGOP update (skipped after the last round and with --no_agop)
                if agop_round < args.agop_rounds - 1 and not args.no_agop:
                    agop = compute_agop(model, extract_bag_features, train_set, device)
                    model.update_agop(agop, momentum=args.agop_momentum)

            # Final results
            print(f'\n{"="*50}')
            print(f'Final Results')
            print(f'{"="*50}')
            print('Train:', end='')
            final_train = evaluate(train_set, train_set, model)
            print('Test:', end='')
            final_test = evaluate(train_set, test_set, model)
            print(f'Best Test: {best_acc:.3f}')

            all_fold_results.append({
                'fold': fold_i + 1,
                'final': final_test,
                'best': best_acc
            })

            torch.cuda.empty_cache()

    # Summary (mean and standard deviation of the per-fold best accuracy;
    # the per-fold 'final' accuracy is stored but not printed here)
    print(f'\n{"="*60}')
    print(f'EXPERIMENT SUMMARY')
    print(f'{"="*60}')
    if all_fold_results:
        best_accs = [r['best'] for r in all_fold_results]
        print(f'Mean Best Accuracy: {np.mean(best_accs):.3f} ± {np.std(best_accs):.3f}')
        for r in all_fold_results:
            print(f"  Fold {r['fold']}: {r['best']:.3f}")

    print(f'\nTotal time: {time.time() - time_s:.1f}s')
    print('Done.')
