import os
import torch
import pickle
import argparse
from Experiments.semi_discrete_OT import *
from Experiments.scores import *
import os.path as osp
from Experiments.utils import *
from data_preprocess.VisDA_preprocess import data_load

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the option could never be
    disabled from the command line.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command setting
parser = argparse.ArgumentParser(description='OT_score')

# arguments for loading models
parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
parser.add_argument('--bottleneck', type=int, default=256)
parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epsilon', type=float, default=0.5)
parser.add_argument('--algorithm', type=str, default='Unknown')
# Parsed with _str2bool so that "--src_mean_feture False" actually disables it.
parser.add_argument('--src_mean_feture', type=_str2bool, default=True,
                    help="collapse source features to per-class means (true/false)")

parser.add_argument('--data', type=str, default='VisDA_17')
# Default must be one of the declared choices (it was "target", which the
# parser would itself reject if passed explicitly).
parser.add_argument('--feature_loading', type=str, default="tar", choices=["tar", "src"])
# Raw string: the path contains backslashes that are not valid escape sequences.
parser.add_argument('--feature_folder', type=str, default=r"F:\OT_Score/feature_extractor\save_features")


parser.add_argument('--seed', type=int, default=2020, help="random seed")
parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")

parser.add_argument('--s', type=int, default=0, help="source")
parser.add_argument('--t', type=int, default=1, help="target")
parser.add_argument('--worker', type=int, default=4, help="number of workers")

parser.add_argument('--src', type=str, default='webcam')
parser.add_argument('--tar', type=str, default='amazon')
parser.add_argument('--da', type=str, default='uda')


args = parser.parse_args()

def load_feature_info(path, device='cuda'):
    """Load saved source/target features, probabilities and labels from a pickle.

    The pickle must contain a dict with the keys ``src_features``,
    ``src_probs``, ``src_labels``, ``tar_features``, ``tar_probs``,
    ``tar_labels`` and ``tar_pred_labels``. Each value may be:

    * a ``torch.Tensor``, which is moved to ``device``;
    * a list/tuple of tensors, which are stacked along a new leading dim;
    * a list/tuple of Python numbers (possibly empty), passed to ``torch.tensor``;
    * an array-like exposing ``__array__`` (e.g. a NumPy array), passed to
      ``torch.as_tensor``.

    Any other value is returned unchanged.

    Args:
        path: Path of the pickle file to read.
        device: Device on which every converted tensor is placed.

    Returns:
        ``(src_features, src_probs, src_labels, tar_features, tar_probs,
        tar_labels, tar_pred_labels)``. The three label entries are cast to
        int64 with ``.long()``.

    Raises:
        KeyError: If one of the expected keys is missing from the pickle.

    Security: ``pickle.load`` can execute arbitrary code. Only load feature
    files produced by your own pipeline.
    """
    with open(path, 'rb') as f:
        data = pickle.load(f)

    def process(x):
        """Convert one stored value to a tensor on ``device`` when possible."""
        if isinstance(x, torch.Tensor):
            return x.to(device)
        if isinstance(x, (list, tuple)):
            # Guard on emptiness: indexing x[0] on an empty list raised IndexError.
            if x and isinstance(x[0], torch.Tensor):
                return torch.stack(list(x)).to(device)
            return torch.tensor(x).to(device)
        if hasattr(x, '__array__'):
            # NumPy arrays and similar; previously returned as-is, which broke
            # the `.long()` casts below and later torch operations.
            return torch.as_tensor(x).to(device)
        return x

    src_features = process(data['src_features'])
    src_probs = process(data['src_probs'])
    src_labels = process(data['src_labels']).long()

    tar_features = process(data['tar_features'])
    tar_probs = process(data['tar_probs'])
    tar_labels = process(data['tar_labels']).long()
    tar_pred_labels = process(data['tar_pred_labels']).long()

    print("✔ All features loaded.")
    print(f"src_features: {src_features.shape}, src_probs: {src_probs.shape}, src_labels: {src_labels.shape}")
    print(f"tar_features: {tar_features.shape}, tar_probs: {tar_probs.shape}, tar_labels: {tar_labels.shape}, tar_pred_labels: {tar_pred_labels.shape}")

    return src_features, src_probs, src_labels, tar_features, tar_probs, tar_labels, tar_pred_labels

if __name__ == '__main__':
    # Restrict the visible GPUs *before* the first CUDA query. Once
    # torch.cuda.is_available() has run, the CUDA runtime has already
    # enumerated devices and a later CUDA_VISIBLE_DEVICES assignment is ignored.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Number of classes per dataset. The original code tested "office31" twice,
    # so office31 was silently overwritten with 65 classes.
    # NOTE(review): 65 is the Office-Home class count. Confirm that
    # "office_home" is the dataset key used by the rest of the project.
    num_classes_by_dataset = {"VisDA_17": 12, "office31": 31, "office_home": 65}
    if args.data in num_classes_by_dataset:
        args.class_num = num_classes_by_dataset[args.data]

    print("successfully load the model")

    feature_save_path = os.path.join(args.feature_folder, f"{args.src}2{args.tar}_features.pkl")

    # load_feature_info already places every tensor on `device`.
    (src_feature, src_prob, src_true_label, tar_feature,
     tar_prob, tar_true_label, tar_pred_label) = load_feature_info(feature_save_path, device)
    src_pred_label = torch.argmax(src_prob, dim=1).squeeze()
    tar_true_label = tar_true_label.squeeze()
    src_true_label = src_true_label.squeeze()
    print("accuracy on target domain:", compute_accuracy(tar_true_label, tar_pred_label))
    print("accuracy on src domain:", compute_accuracy(src_true_label, src_pred_label))

    if args.src_mean_feture:
        # Replace the source samples with the per-class means from get_class_mean.
        src_feature, src_label = get_class_mean(src_feature, src_true_label)
    else:
        # Without this branch `src_label` was never bound and the
        # Probability_Measure construction below raised NameError.
        src_label = src_true_label

    # Baseline confidence scores computed from the target predictions.
    Maxprob_score = Maxprob(tar_prob)
    Ent_score = Ent(tar_prob)
    JMDS_score = JMDS(tar_feature, tar_prob)
    print(JMDS_score)

    # Measures for the OT problem. They are named *_measure so that the
    # probability tensors src_prob / tar_prob are not shadowed.
    src_measure = Probability_Measure(src_feature, weight=None, label=src_label)  # location, weight=None, label=None
    # NOTE(review): the source weights come from the *target ground-truth* label
    # proportions. label_proportions presumably returns one weight per class,
    # which matches the source support only when --src_mean_feture is on.
    # Confirm both points are intended.
    src_measure.weight = label_proportions(tar_true_label).unsqueeze(1).to(device)

    tar_measure = Probability_Measure(tar_feature, weight=None, label=tar_true_label)
    tar_measure.predicted_label = tar_pred_label

    semi_ot = Semi_Discrete_OT(tar_measure, src_measure, cost=None)  # nu, mu
    semi_ot.compute_OT(lr=None, max_iter=5000, batch_size=2000, epsilon=args.epsilon)
    print(semi_ot.reweight_factors)
    semi_ot.get_weighted_distance()

    if args.algorithm == 'SDDA':
        semi_ot.classify(epsilon=args.epsilon)

    OTSCORE = OT_Score(semi_ot)
    OTSCORE.compute_ot_score()
    OTSCORE.get_min_ot_score()

    # Compare all scores via AURC at the given thresholds.
    score_dict = {'Maxprob': Maxprob_score, 'Ent': Ent_score, 'JMDS': JMDS_score, 'OT_Score': OTSCORE.min_score}
    thresholds = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    plot_aurc(tar_true_label, tar_pred_label, score_dict, thresholds)
