# -*- coding: utf-8 -*-
"""

"""

import os
import argparse
from importlib import reload 
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


# -- Dataset splitting indices --
def split_indices(labels, train_ratio, val_ratio, seed):
    """Split sample indices into train/val/test sets, class by class.

    For every distinct label the matching indices are shuffled (after seeding
    NumPy's global RNG with ``seed``) and cut into three consecutive parts:
    the first ``int(n * train_ratio)`` go to train, the next
    ``int(n * val_ratio)`` to validation and the remainder to test.

    Args:
        labels: NumPy array of class labels, one per sample.
        train_ratio: fraction of each class assigned to the train split.
        val_ratio: fraction of each class assigned to the validation split.
        seed: value passed to ``np.random.seed`` (this resets the global RNG).

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)`` of NumPy index arrays, each
        ordered class by class.
    """
    np.random.seed(seed)
    buckets = ([], [], [])  # train, val, test
    for cls in np.unique(labels):
        members = np.nonzero(labels == cls)[0]
        np.random.shuffle(members)
        train_end = int(len(members) * train_ratio)
        val_end = train_end + int(len(members) * val_ratio)
        parts = (members[:train_end], members[train_end:val_end], members[val_end:])
        for bucket, part in zip(buckets, parts):
            bucket.extend(part.tolist())
    train_idx, val_idx, test_idx = (np.array(bucket) for bucket in buckets)
    return train_idx, val_idx, test_idx

# -- Sample N-shot support set --
def sample_support(labels, shots, seed):
    """Pick up to ``shots`` random sample indices for every class.

    Args:
        labels: NumPy array of class labels, one per sample.
        shots: number of indices to keep per class (fewer are returned for a
            class that has fewer than ``shots`` members).
        seed: value passed to ``np.random.seed`` (this resets the global RNG).

    Returns:
        NumPy array of the selected indices, ordered class by class.
    """
    np.random.seed(seed)
    picked = []
    for cls in np.unique(labels):
        candidates = np.nonzero(labels == cls)[0]
        np.random.shuffle(candidates)
        picked.extend(candidates[:shots].tolist())
    return np.array(picked)

    
    
class NoiseMLPAdapter(nn.Module):
    """Classifier whose input features are gated by a noise descriptor.

    A small MLP maps the noise vector to a sigmoid gate with ``feat_dim``
    entries; the features are modulated as ``feat * gate + feat`` and fed to
    an MLP classification head.

    Args:
        feat_dim: width of the feature vectors passed to ``forward``.
        noise_dim: width of the noise vectors passed to ``forward``.
        hidden_dim: hidden width of both the gate MLP and the classifier.
        num_classes: number of output logits.
    """

    def __init__(self, feat_dim, noise_dim, hidden_dim=256, num_classes= 11):
        super().__init__()
        self.noise_mlp = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feat_dim),
            nn.Sigmoid(),
        )

        # The head consumes the modulated features, which have feat_dim
        # columns; the input width must therefore be feat_dim, not hidden_dim
        # (the two only coincide when feat_dim == hidden_dim).
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, feat, noise):
        """Return class logits for ``feat`` modulated by the gate of ``noise``."""
        gate = self.noise_mlp(noise)
        # Residual gating: the sigmoid gate lies in (0, 1), so every feature
        # is scaled by a factor in (1, 2).
        modulated_feat = feat * gate + feat
        return self.classifier(modulated_feat)
    

def run_noise_adapter_train(support_feats, support_noise, support_labels,
                            val_feats, val_noise, val_labels,
                            test_feats, test_noise, test_labels,
                            adapter_module, epochs=10, batch_size = 32, lr=1e-3, save_dir='./results'):
    """Train ``adapter_module`` on the support set and evaluate the best epoch.

    After each epoch the module is scored on the validation set; the state
    with the highest validation accuracy is written to
    ``<save_dir>/best_adapter.pth``. That checkpoint is then reloaded and
    evaluated on the test set, and the test accuracy and predictions are saved
    as ``acc.npy`` and ``test_pred.npy`` in ``save_dir``.

    Args:
        support_feats, support_noise, support_labels: training tensors with
            one row/entry per support sample. The device of ``support_feats``
            is used for the module and for loading the checkpoint.
        val_feats, val_noise, val_labels: validation tensors.
        test_feats, test_noise, test_labels: test tensors.
        adapter_module: module called as ``adapter_module(feats, noise)`` and
            returning class logits.
        epochs: number of passes over the support set.
        batch_size: mini-batch size for the support loader.
        lr: Adam learning rate.
        save_dir: output directory; created if it does not exist.

    Returns:
        None. Results are written to ``save_dir``.
    """
    logging.info("\n-------- Training noise-aware adapter on the support set. --------")
    device = support_feats.device
    adapter_module = adapter_module.to(device)

    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, 'best_adapter.pth')

    # Features and noise stay separate tensors so that the noise width is not
    # hard-coded when a batch is unpacked.
    train_loader = DataLoader(
        TensorDataset(support_feats, support_noise, support_labels),
        batch_size=batch_size, shuffle=True
    )

    # The scheduler is stepped once per batch, hence T_max counts batches.
    optimizer = torch.optim.Adam(adapter_module.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(train_loader))

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run stuck at 0% would load a missing or stale file.
    best_val_acc = float('-inf')
    for epoch in range(epochs):
        adapter_module.train()

        for x_feats, x_noise, y in train_loader:
            x_feats, x_noise, y = x_feats.to(device), x_noise.to(device), y.to(device)
            logits = adapter_module(x_feats, x_noise)
            loss = F.cross_entropy(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        adapter_module.eval()
        with torch.no_grad():
            val_logits = adapter_module(val_feats, val_noise)
            val_pred = val_logits.argmax(dim=-1)
            val_acc = (val_pred == val_labels).float().mean().item() * 100

        logging.info(f"Epoch {epoch+1}: Val Acc = {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'adapter': adapter_module.state_dict(),
                'val_acc': val_acc
            }, ckpt_path)

    logging.info("\n-------- Evaluating on the test set. --------")
    # map_location keeps the reload on the device the module already lives on.
    checkpoint = torch.load(ckpt_path, map_location=device)
    adapter_module.load_state_dict(checkpoint['adapter'])
    adapter_module.eval()
    with torch.no_grad():
        test_logits = adapter_module(test_feats, test_noise)
        test_pred = test_logits.argmax(dim=-1)
        test_acc = (test_pred == test_labels).float().mean().item() * 100

    logging.info(f"**** Noise-Aware Adapter Test Accuracy: {test_acc:.2f}% ****")
    np.save(os.path.join(save_dir, 'acc.npy'), np.array([test_acc]))
    np.save(os.path.join(save_dir, 'test_pred.npy'), test_pred.cpu().numpy())

        
def simulate_realistic_augmented_feat(h, g, g2, alpha=0.1):
    """Create a noisy, re-gated copy of the features ``h``.

    Gaussian noise with per-entry standard deviation
    ``alpha * clamp(1 - g, min=0)`` is added to ``h``, and the result is
    scaled by ``1 + g2``. The computation runs without gradient tracking.

    Args:
        h: feature tensor to augment.
        g: gate values controlling the noise scale (tensor or NumPy array,
            broadcastable against ``h``); larger gate means less noise.
        g2: gate values used to re-modulate the noisy features (tensor or
            NumPy array, broadcastable against ``h``).
        alpha: global multiplier for the noise standard deviation.

    Returns:
        Tensor with the shape, dtype and device of ``h``.
    """
    with torch.no_grad():
        # Accept tensors and arrays alike and follow h's dtype/device instead
        # of forcing a NumPy -> CUDA conversion.
        g = torch.as_tensor(g, dtype=h.dtype, device=h.device)
        g2 = torch.as_tensor(g2, dtype=h.dtype, device=h.device)

        noise_std = alpha * (1 - g).clamp(min=0.0)
        h_noisy = h + torch.randn_like(h) * noise_std
        h_aug = (1 + g2) * h_noisy
    return h_aug
    
        
def run_noise_adapter_feasaug_trainv2(support_feats, support_labels,
                            val_feats, val_labels,
                            test_feats, test_labels, sup_gate, sup_feats0, best_val_acc_before, best_test_acc_before,
                            classifier, epochs=10, batch_size = 32, lr=1e-3, save_dir='./results'):
    """Train ``classifier`` on support features with gate-based augmentation.

    After more than 10 optimisation steps have been completed, every batch is
    doubled with an augmented copy produced by
    ``simulate_realistic_augmented_feat`` (noise scaled by the mean gate,
    re-modulated by randomly drawn rows of ``sup_gate``). The state with the
    best validation accuracy is stored in ``<save_dir>/best_cls.pth``,
    reloaded and evaluated on the test set. ``acc.npy`` and ``test_pred.npy``
    are only (over)written when the best validation accuracy is at least
    ``best_val_acc_before``.

    Args:
        support_feats, support_labels: training tensors; the device of
            ``support_feats`` is used for the classifier.
        val_feats, val_labels: validation tensors.
        test_feats, test_labels: test tensors.
        sup_gate: 2-D tensor of gate rows to sample from for augmentation.
        sup_feats0: accepted for interface compatibility; not used.
        best_val_acc_before: validation accuracy the new run must match or
            beat for its test results to be saved.
        best_test_acc_before: accepted for interface compatibility; not used.
        classifier: module mapping features to class logits.
        epochs: number of passes over the support set.
        batch_size: mini-batch size for the support loader.
        lr: Adam learning rate.
        save_dir: output directory; created if it does not exist.

    Returns:
        None. Results are written to ``save_dir``.
    """
    gate_pool = list(range(sup_gate.size(0)))

    logging.info("\n-------- Training noise-aware adapter on the support set. --------")
    device = support_feats.device
    classifier = classifier.to(device)

    os.makedirs(save_dir, exist_ok=True)
    ckpt_path = os.path.join(save_dir, 'best_cls.pth')

    # Only the modulated support features are trained on, so they are loaded
    # directly rather than being sliced out of a concatenation at a fixed
    # column offset.
    train_loader = DataLoader(
        TensorDataset(support_feats, support_labels),
        batch_size=batch_size, shuffle=True
    )

    # The scheduler is stepped once per batch, hence T_max counts batches.
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(train_loader))

    warmup_steps = 10   # plain training until more than this many steps ran
    steps_done = 0
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run stuck at 0% would load a missing or stale file.
    best_val_acc = float('-inf')

    gmean = sup_gate.mean(dim=0)
    for epoch in range(epochs):
        classifier.train()

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            if steps_done > warmup_steps:
                # Twice the batch size is drawn and only the second half is
                # used; kept as-is so the NumPy random stream is unchanged.
                gate_idx = np.random.choice(gate_pool, size=x.size(0) * 2, replace=True)
                noisy_x = simulate_realistic_augmented_feat(x, gmean, sup_gate[gate_idx[x.size(0):], :], alpha=0.025)

                combined_x = torch.cat([x, noisy_x], dim=0)
                combined_y = torch.cat([y, y], dim=0)  # augmented copies keep the same labels

                logits = classifier(combined_x)
                loss = F.cross_entropy(logits, combined_y)
            else:
                logits = classifier(x)
                loss = F.cross_entropy(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            steps_done += 1

        classifier.eval()
        with torch.no_grad():
            val_logits = classifier(val_feats)
            val_pred = val_logits.argmax(dim=-1)
            val_acc = (val_pred == val_labels).float().mean().item() * 100

        logging.info(f"Epoch {epoch+1}: Val Acc = {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'classifier': classifier.state_dict(),
                'val_acc': val_acc
            }, ckpt_path)

    logging.info("\n-------- Evaluating on the test set. --------")
    # map_location keeps the reload on the device the classifier already lives on.
    checkpoint = torch.load(ckpt_path, map_location=device)
    classifier.load_state_dict(checkpoint['classifier'])
    classifier.eval()

    with torch.no_grad():
        test_logits = classifier(test_feats)
        test_pred = test_logits.argmax(dim=-1)
        test_acc = (test_pred == test_labels).float().mean().item() * 100

    logging.info(f"**** Noise-Aware Adapter Test Accuracy: {test_acc:.2f}% ****")

    # Keep the earlier results unless this stage validates at least as well.
    if best_val_acc >= best_val_acc_before:
        np.save(os.path.join(save_dir, 'acc.npy'), np.array([test_acc]))
        np.save(os.path.join(save_dir, 'test_pred.npy'), test_pred.cpu().numpy())

def main(args):
    """Run one few-shot experiment for the configuration in ``args``.

    Seeds the RNGs, loads features/labels/noise from the .npy paths in
    ``args``, builds the train/val/test split and the N-shot support set, and
    writes everything to ``<output_dir>/<method>_alpha0020_shots<k>_seed<s>``.
    For ``args.method == 'noise_adapter_feasaug'`` it first trains a
    ``NoiseMLPAdapter``, then modulates all features with the learned gate and
    trains a fresh classifier on them with gate-based augmentation.

    Args:
        args: parsed command-line namespace (paths, seed, device, shots,
            split ratios, hidden_dim, epochs, batch_size, lr, method).
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = torch.device(args.device)

    exp = f"{args.method}_alpha0020_shots{args.shots}_seed{args.seed}"
    save_dir = os.path.join(args.output_dir, exp)
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): logging is reloaded but no handler/level is configured
    # afterwards, so logging.info output is dropped at the default WARNING
    # level — confirm whether a basicConfig call is missing here.
    reload(logging)

    feats = np.load(args.image_feat_path)

    if args.use_noise:
        # Optionally append the raw noise descriptors to the features.
        noise = np.load(args.noise_path)
        feats = np.hstack((feats, noise))

    labels = np.load(args.label_path)
    tr, vl, te = split_indices(labels, args.train_ratio, args.val_ratio, args.seed)
    np.save(os.path.join(save_dir, 'tr_idx.npy'), tr)
    np.save(os.path.join(save_dir, 'vl_idx.npy'), vl)
    np.save(os.path.join(save_dir, 'te_idx.npy'), te)

    all_feats = torch.from_numpy(feats).float().to(device)
    all_labels = torch.from_numpy(labels).long().to(device)
    all_feats /= all_feats.norm(dim=-1, keepdim=True)  # L2-normalise each row

    # Support indices are positions inside the train split, not global indices.
    sup_idx = sample_support(all_labels[tr].cpu().numpy(), args.shots, args.seed)
    sup_feats = all_feats[tr][sup_idx]
    sup_labels = all_labels[tr][sup_idx]
    val_feats, val_labels = all_feats[vl], all_labels[vl]
    test_feats, test_labels = all_feats[te], all_labels[te]
    # Text features are loaded and normalised but not used by the
    # 'noise_adapter_feasaug' branch below.
    clip_w = torch.from_numpy(np.load(args.text_feat_path)).float().to(device)
    clip_w /= clip_w.norm(dim=0, keepdim=True)

    # Standardise every noise column over the whole dataset.
    noise = np.load(args.noise_path)
    noise = torch.from_numpy(noise)
    mean = noise.mean(dim=0, keepdim=True)
    std = noise.std(dim=0, keepdim=True) + 1e-6
    noise = (noise - mean) / std
    support_noise = noise[tr][sup_idx].float().to(device)
    val_noise = noise[vl].float().to(device)
    test_noise = noise[te].float().to(device)

    if args.method == 'noise_adapter_feasaug':
        num_classes = int(labels.max()) + 1

        # Stage 1: train the noise-gated adapter on the support set.
        adapter = NoiseMLPAdapter(sup_feats.size(1), support_noise.size(1),
                                  hidden_dim=args.hidden_dim, num_classes=num_classes)
        run_noise_adapter_train(sup_feats, support_noise, sup_labels,
                                    val_feats, val_noise, val_labels,
                                    test_feats, test_noise, test_labels,
                                    adapter, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, save_dir=save_dir)

        # map_location keeps the reload working when the run is on CPU.
        checkpoint = torch.load(os.path.join(save_dir, 'best_adapter.pth'), map_location=device)
        best_val_acc_before = checkpoint['val_acc']
        best_test_acc_before = np.load(os.path.join(save_dir, 'acc.npy'))
        adapter.load_state_dict(checkpoint['adapter'])
        adapter.eval()

        # Stage 2: modulate every feature with the learned gate.
        with torch.no_grad():
            gate = adapter.noise_mlp(noise.float().to(device))
            modulated_feat = all_feats * gate + all_feats

        sup_feats0 = sup_feats.clone()
        sup_feats = modulated_feat[tr][sup_idx]
        sup_gate = gate[tr][sup_idx]
        val_feats = modulated_feat[vl]
        test_feats = modulated_feat[te]

        classifier = nn.Sequential(
            nn.Linear(sup_feats.size(1), args.hidden_dim),
            nn.BatchNorm1d(args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, num_classes)
        )

        # NOTE(review): the gates of ALL samples (`gate`) are passed as the
        # `sup_gate` argument, while the support-only `sup_gate` computed
        # above is unused. Left unchanged to keep results reproducible —
        # confirm which one is intended.
        run_noise_adapter_feasaug_trainv2(sup_feats, sup_labels,
                                    val_feats, val_labels,
                                    test_feats, test_labels, gate, sup_feats0, best_val_acc_before, best_test_acc_before,
                                    classifier, epochs=args.epochs, batch_size = args.batch_size, lr=args.lr, save_dir=save_dir)
        
        



def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``type=bool`` would treat every non-empty string (including 'False') as
    True, so the text is interpreted explicitly instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


p = argparse.ArgumentParser()

# Input files. The defaults are built with os.path.join because literals such
# as '.\totfeas.npy' or '.\noise.npy' contain the escapes '\t' and '\n' and
# therefore do not name the intended files.
p.add_argument('--image_feat_path', default=os.path.join('.', 'totfeas.npy'), type=str)
p.add_argument('--noise_path', default=os.path.join('.', 'noise.npy'), type=str)
p.add_argument('--use_noise', default=False, type=_str2bool)
p.add_argument('--label_path', default=os.path.join('.', 'totlabels.npy'), type=str)
p.add_argument('--text_feat_path', default=os.path.join('.', 'textfeas.npy'), type=str)

# Experiment setup.
p.add_argument('--output_dir', default=os.path.join('.', 'experiments'))
# argparse formats `help` as a string; passing a list makes --help fail.
p.add_argument('--method', default='noise_adapter_feasaug', type=str,
               help="one of: noise_adapter_feasaug, tip_adapter, tip_adapter_f")
p.add_argument('--train_ratio', type=float, default=0.6)
p.add_argument('--val_ratio', type=float, default=0.1)
p.add_argument('--shots', default=1, type=int)
p.add_argument('--seed', type=int, default=1)
p.add_argument('--device', default='cuda')
# model args
p.add_argument('--hidden_dim', type=int, default=512)
p.add_argument('--alpha', type=float, default=0.5)
p.add_argument('--beta', type=float, default=0)
p.add_argument('--clip_scale', type=float, default=100.0)
p.add_argument('--lr', type=float, default=0.001)
p.add_argument('--epochs', type=int, default=100)
p.add_argument('--batch_size', type=int, default=64)

args = p.parse_args()

# Sweep: every method x shot count, repeated num_exp times with seeds
# 1..num_exp. The --method, --shots and --seed values given on the command
# line are overwritten by this loop.
methods = ['noise_adapter_feasaug']
shots = [1,2,4,8,16]
num_exp = 5
for im in methods:
    for ishot in shots:
        args.method = im
        args.shots = ishot
        print(im, ishot)
        for iexp in range(num_exp):
            args.seed = iexp + 1
            main(args)