import argparse
import json
import os
import random
import sys
import numpy as np
import PIL
import torch
import torchvision
import torch.utils.data
from tensorboard_logger import Logger
from sklearn.neural_network import MLPClassifier
from torch.utils.data import DataLoader

import hparams_registry
from dataset import datasets
from learning import algorithms
from utils import misc
from gen import *
from baselines import *
from params import *
from utils_glue import *

def mse_loss(X, weight, y):
    """Return the mean squared error of the linear prediction ``X @ weight`` against ``y``."""
    residual = torch.matmul(X, weight) - y
    return residual.pow(2).mean()

def mseTorch(X, y, num_steps, lr, lambdda, device):
    """Fit mixing weights so that ``X.T @ weight`` approximates ``y`` in the MSE sense.

    The weights are initialised to ones and optimised with Adam on the mean
    squared error plus an optional L1 (lasso) penalty.

    Args:
        X: 2-D array or tensor; it is transposed before fitting, so one weight
            is learned per *row* of ``X``.
        y: 1-D array or tensor with one entry per *column* of ``X``.
        num_steps: Number of Adam steps to run.
        lr: Adam learning rate.
        lambdda: Coefficient of the L1 penalty on the weights (0 disables it).
        device: Torch device (or device string) on which to run the optimisation.

    Returns:
        A 1-D NumPy array holding the absolute value of each fitted weight.
    """
    # `as_tensor` accepts both arrays and tensors; `torch.tensor(<tensor>)`
    # would emit a copy-construct warning. `detach` keeps any upstream graph
    # out of this optimisation.
    X = torch.as_tensor(X, dtype=torch.float32).detach().to(device).T
    y = torch.as_tensor(y, dtype=torch.float32).detach().to(device)
    n, p = X.shape
    weight = torch.ones(p, device=device, requires_grad=True)
    # Reference the optimiser through `torch` so this function does not rely
    # on a star import happening to export `optim`.
    optimizer = torch.optim.Adam([weight], lr=lr)
    for i in range(num_steps):
        optimizer.zero_grad()
        dis = mse_loss(X, weight, y)
        # lasso
        loss = dis + lambdda * weight.abs().sum()
        if i == 0 or (i + 1) % 500 == 0:
            print(f"iter {i}  dis: {dis.item()}  loss: {loss.item()}")
        loss.backward()
        optimizer.step()
    return weight.detach().abs().cpu().numpy()

def _bucket_inputs_by_group(items, comb_list):
    """Split the model inputs of ``items`` into one list per (label, attribute) pair.

    Args:
        items: Iterable of samples; ``item[1]`` is used as the model input,
            ``item[2]`` as the label and ``item[3]`` as the attribute.
        comb_list: Sequence of ``(label, attribute)`` pairs defining the groups.

    Returns:
        A list parallel to ``comb_list`` whose i-th entry holds the inputs of
        every sample matching ``comb_list[i]``. Samples matching no pair are
        dropped.
    """
    buckets = [[] for _ in comb_list]
    for item in items:
        y, a = item[2], item[3]
        for idx, (comb_y, comb_a) in enumerate(comb_list):
            if y == comb_y and a == comb_a:
                buckets[idx].append(item[1])
                break
    return buckets


def _evaluate_group_loaders(algorithm, loader_list, comb_list, device, empty_cache_per_batch=False):
    """Run ``algorithm`` over each group loader, collecting features and correctness.

    Args:
        algorithm: Model exposing ``predict`` and ``return_feats``.
        loader_list: One loader of input batches per entry of ``comb_list``.
        comb_list: Sequence of ``(label, attribute)`` pairs; the label of the
            i-th pair is the ground truth for every sample of the i-th loader.
        device: Device the batches are moved to.
        empty_cache_per_batch: If true, call ``torch.cuda.empty_cache()``
            before every batch.

    Returns:
        ``(feats_list, correct_list)``: per group, a NumPy array of features
        and a boolean NumPy array flagging correctly predicted samples.

    Raises:
        ValueError: If a group has no samples.
    """
    feats_list, correct_list = [], []
    with torch.no_grad():
        for (true_y, true_a), loader in zip(comb_list, loader_list):
            group_correct, group_feats = [], []
            for batch_x in loader:
                if empty_cache_per_batch:
                    torch.cuda.empty_cache()
                batch_x = batch_x.to(device)
                result = algorithm.predict(batch_x)
                group_feats.append(algorithm.return_feats(batch_x).detach().cpu().numpy())
                pred_y = torch.argmax(result, dim=1).cpu()
                group_correct.append((pred_y == true_y).numpy())
            if not group_feats:
                raise ValueError(f"No samples found for group (y={true_y}, a={true_a})")
            correct_list.append(np.concatenate(group_correct, axis=0))
            feats_list.append(np.concatenate(group_feats, axis=0))
    return feats_list, correct_list


if __name__ == "__main__":
    # Local import: only needed to normalise the seed in the stage-1 checkpoint path.
    import re

    parser = argparse.ArgumentParser(description='Subpopulation Shift Benchmark')
    # training
    parser.add_argument('--dataset', type=str, default="Waterbirds", choices=datasets.DATASETS)
    parser.add_argument('--algorithm', type=str, default="ERM", choices=algorithms.ALGORITHMS)
    parser.add_argument('--output_folder_name', type=str, default='debug')
    parser.add_argument('--train_attr', type=str, default="yes", choices=['yes', 'no'])
    # others
    parser.add_argument('--data_dir', type=str, default="./data")
    parser.add_argument('--output_dir', type=str, default="./output")
    parser.add_argument('--hparams', type=str, help='JSON-serialized hparams dict')
    parser.add_argument('--hparams_seed', type=int, default=0, help='Seed for random hparams (0 for "default hparams")')
    parser.add_argument('--seed', type=int, default=0, help='Seed for everything else')
    parser.add_argument('--steps', type=int, default=None)
    parser.add_argument('--tb_log_all', action='store_true')
    # two-stage related
    parser.add_argument('--stage1_folder', type=str, default='vanilla')
    parser.add_argument('--stage1_algo', type=str, default='ERM')
    # early stopping
    parser.add_argument('--use_es', action='store_true')
    parser.add_argument('--es_strategy', choices=['metric'], default='metric')
    parser.add_argument('--es_metric', type=str, default='min_group:accuracy')
    parser.add_argument('--es_patience', type=int, default=5, help='Stop after this many checkpoints w/ no improvement')
    # checkpoints
    parser.add_argument('--resume', '-r', type=str, default='')
    parser.add_argument('--pretrained', type=str, default='')
    parser.add_argument('--checkpoint_freq', type=int, default=None, help='Checkpoint every N steps')
    parser.add_argument('--skip_model_save', action='store_true')
    # CMNIST data params
    parser.add_argument('--cmnist_label_prob', type=float, default=0.5)
    parser.add_argument('--cmnist_attr_prob', type=float, default=0.5)
    parser.add_argument('--cmnist_spur_prob', type=float, default=0.2)
    parser.add_argument('--cmnist_flip_prob', type=float, default=0.25)
    # Params for SATE
    parser.add_argument('--resnet18', action='store_true', help="with this, we use resetnet18 instead of resnet50")
    parser.add_argument('--validate', action='store_true', help="with this, we has validation set and don't have to test on train set")
    parser.add_argument("--llm_augment", action='store_true')
    parser.add_argument('--sample_reweight', action='store_true', help="with this, we do sample reweight on training set")
    parser.add_argument('--fog', type=float, default=0, help="the degree of fog added to test image")
    parser.add_argument('--blur', type=float, default=0, help="the degree of blur added to test image")
    parser.add_argument('--noise', type=float, default=0, help="the degree of noise added to test image")
    parser.add_argument('--bright', type=float, default=0, help="the degree of bright added to test image")
    parser.add_argument('--contrast', type=float, default=0, help="the degree of contrast added to test image")
    # architectures and pre-training sources
    parser.add_argument('--image_arch', default='resnet_sup_in1k',
                        choices=['resnet_sup_in1k', 'resnet_sup_in21k', 'resnet_simclr_in1k', 'resnet_barlow_in1k',
                                 'vit_sup_in1k', 'vit_sup_in21k', 'vit_clip_oai', 'vit_clip_laion', 'vit_sup_swag',
                                 'vit_dino_in1k', 'resnet_dino_in1k'])
    parser.add_argument('--text_arch', default='bert-base-uncased',
                        choices=['bert-base-uncased', 'gpt2', 'xlm-roberta-base',
                                 'allenai/scibert_scivocab_uncased', 'distilbert-base-uncased'])
    args = parser.parse_args()

    # ---- Output locations and logging ----
    start_step = 0
    store_prefix = f"{args.dataset}_{args.cmnist_label_prob}_{args.cmnist_attr_prob}_{args.cmnist_spur_prob}" \
                   f"_{args.cmnist_flip_prob}" if args.dataset == "CMNIST" else args.dataset
    args.store_name = f"{store_prefix}_{args.algorithm}_hparams{args.hparams_seed}_seed{args.seed}"
    args.output_folder_name += "_attrYes" if args.train_attr == 'yes' else "_attrNo"

    misc.prepare_folders(args)
    args.output_dir = os.path.join(args.output_dir, args.output_folder_name, args.store_name)
    sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt'))
    sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt'))

    tb_logger = Logger(logdir=args.output_dir, flush_secs=2)

    print("Environment:")
    print("\tPython: {}".format(sys.version.split(" ")[0]))
    print("\tPyTorch: {}".format(torch.__version__))
    print("\tTorchvision: {}".format(torchvision.__version__))
    print("\tCUDA: {}".format(torch.version.cuda))
    print("\tCUDNN: {}".format(torch.backends.cudnn.version()))
    print("\tNumPy: {}".format(np.__version__))
    print("\tPIL: {}".format(PIL.__version__))

    print('Args:')
    for k, v in sorted(vars(args).items()):
        print('\t{}: {}'.format(k, v))

    # ---- Hyper-parameters ----
    if args.hparams_seed == 0:
        hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
    else:
        hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, misc.seed_hash(args.hparams_seed))
    if args.hparams:
        hparams.update(json.loads(args.hparams))
    if args.dataset == "CMNIST":
        # Each CMNIST hparam mirrors the CLI flag of the same name.
        hparams.update({'cmnist_label_prob': args.cmnist_label_prob,
                        'cmnist_attr_prob': args.cmnist_attr_prob,
                        'cmnist_spur_prob': args.cmnist_spur_prob,
                        'cmnist_flip_prob': args.cmnist_flip_prob})

    hparams.update({
        'image_arch': args.image_arch,
        'text_arch': args.text_arch,
        'resnet18': args.resnet18
    })

    print('HParams:')
    for k, v in sorted(hparams.items()):
        print('\t{}: {}'.format(k, v))

    with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    # ---- Reproducibility ----
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    torch.multiprocessing.set_sharing_strategy('file_system')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Datasets ----
    meta_path = "0.csv"
    if args.dataset in vars(datasets):
        dataset_class = vars(datasets)[args.dataset]
        train_dataset = dataset_class(args.data_dir, 'tr', hparams, train_attr=args.train_attr, meta_path=meta_path)
        val_dataset = dataset_class(args.data_dir, 'va', hparams, meta_path=meta_path)
        test_dataset = dataset_class(args.data_dir, 'te', hparams, meta_path=meta_path)
    else:
        raise NotImplementedError

    num_workers = train_dataset.N_WORKERS
    input_shape = train_dataset.INPUT_SHAPE
    num_labels = train_dataset.num_labels
    num_attributes = train_dataset.num_attributes
    data_type = train_dataset.data_type
    n_steps = args.steps or train_dataset.N_STEPS
    checkpoint_freq = args.checkpoint_freq or train_dataset.CHECKPOINT_FREQ

    hparams.update({
        "steps": n_steps
    })

    # (label, attribute) groups and the list of test configurations per dataset.
    if args.dataset == "MultiNLI":
        comb_list = [(0,0), (0,1), (1,0), (1,1), (2,0), (2,1)]
        total_n_test = n6_test
    elif args.dataset == "SNLI":
        comb_list = [(0,0), (0,1), (1,0), (1,1), (2,0), (2,1)]
        total_n_test = n6_test2
    elif args.dataset == "CheXpertNoFinding":
        comb_list = [(y, a) for y in range(2) for a in range(6)]
        total_n_test = n12_test
    else:
        comb_list = [(0,0), (0,1), (1,0), (1,1)]
        total_n_test = n4_test

    algorithm_class = algorithms.get_algorithm_class(args.algorithm)
    algorithm = algorithm_class(data_type, input_shape, num_labels, num_attributes,
                                len(train_dataset), hparams, grp_sizes=train_dataset.group_sizes)

    # ---- Load stage-1 model ----
    # Derive the stage-1 checkpoint path from this run's output path by
    # swapping folder and algorithm names, then force the seed-0 checkpoint.
    stage1_path = os.path.join(
        args.output_dir.replace(args.output_folder_name, args.stage1_folder), hparams['stage1_model']
    ).replace(args.algorithm, args.stage1_algo)
    # Match the whole seed number so multi-digit seeds (e.g. "seed12") also map to "seed0".
    args.pretrained = re.sub(r"seed\d+", "seed0", stage1_path)
    if not os.path.isfile(args.pretrained):
        raise FileNotFoundError(f"Stage-1 checkpoint not found: '{args.pretrained}'")

    checkpoint = torch.load(args.pretrained, map_location="cpu")
    from collections import OrderedDict
    # Keep every checkpoint entry except the classifier parameters.
    state_dict = OrderedDict(
        (k, v) for k, v in checkpoint['model_dict'].items() if 'classifier' not in k
    )
    algorithm.load_state_dict(state_dict, strict=False)
    print(f"===> Pretrained weights found in total: [{len(list(state_dict.keys()))}]")
    print(f"===> Pre-trained model loaded: '{args.pretrained}'")
    algorithm.to(device)
    algorithm.eval()

    # print_gpu_memory_usage("After Loading model")

    # ---- Per-group mean features of the training data ----
    s_list = _bucket_inputs_by_group(train_dataset, comb_list)
    loader_list = [DataLoader(s, batch_size=256, shuffle=False) for s in s_list]

    with torch.no_grad():
        X = []
        correct, total = 0, 0
        for (true_y, true_a), loader in zip(comb_list, loader_list):
            group_feats = []
            for Xi in loader:
                Xi = Xi.to(device)
                Xi_feats = algorithm.return_feats(Xi).cpu()
                result = algorithm.predict(Xi).cpu()
                pred_y = torch.argmax(result, dim=1).cpu()
                correct += (pred_y == true_y).sum()
                total += len(pred_y)
                group_feats.append(Xi_feats)
            if not group_feats:
                raise ValueError(f"No training samples found for group (y={true_y}, a={true_a})")
            X.append(torch.mean(torch.cat(group_feats, dim=0), axis=0))
            del Xi, Xi_feats
            torch.cuda.empty_cache()
        # One row per group, each the mean feature vector of that group.
        X = torch.stack(X)
        ori_train_acc = correct / total
        print("Original Training Overall Acc: ", ori_train_acc)
    torch.cuda.empty_cache()

    # ---- Per-group accuracy on the reference split ----
    # The reference split is the perturbed training set (default), the
    # LLM-augmented training set, or the validation set.
    empty_cache_per_batch = False
    if not args.validate and not args.llm_augment:
        # Test on Training set, Need to do some perturbation, so reload
        print("Test on Training set...")
        if args.dataset == "MultiNLI" or args.dataset == "SNLI":
            eval_items = (train_dataset[i] for i in range(len(train_dataset)))
        else:
            eval_items = (train_dataset.pertur(i) for i in range(len(train_dataset)))
    elif args.llm_augment:
        print("Test on llm-augmented training data...")
        train_dataset.llm_augment()
        eval_items = (train_dataset[i] for i in range(len(train_dataset)))
    else:
        print("Test on Validation set...")
        # At most the first 100000 validation samples are used.
        eval_items = (val_dataset[i] for i in range(min(len(val_dataset), 100000)))
        empty_cache_per_batch = True

    s_list = _bucket_inputs_by_group(eval_items, comb_list)
    loader_list = [DataLoader(s, batch_size=256, shuffle=False) for s in s_list]
    train_feats_list, correct_list = _evaluate_group_loaders(
        algorithm, loader_list, comb_list, device, empty_cache_per_batch=empty_cache_per_batch)

    torch.cuda.empty_cache()

    # Group accuracies do not depend on the test configuration; compute them once.
    train_accs = np.array([correct.sum() / len(correct) for correct in correct_list])
    for acc in train_accs:
        print("Original Subset Acc: ", acc)

    # ---- Evaluate every test configuration ----
    print("Start Testing...")
    test_accs, others = [], []
    ours_list = []
    weight_list = []
    for name, n_test in enumerate(total_n_test):
        print("n_test: ", n_test)
        # Reload Test data every time
        test_dataset = vars(datasets)[args.dataset](
            args.data_dir, 'te', hparams, meta_path=f"{name}.csv",
            fog_intense=args.fog, blur_intense=args.blur, noise_intense=args.noise,
            bright_intense=args.bright, contrast_intense=args.contrast)
        test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
        total_correct = 0
        corrects = {key: 0 for key in comb_list}
        counts = {key: 0 for key in comb_list}
        test_feats = []
        with torch.no_grad():
            for item in test_loader:
                test_y_batch, test_a_batch, test_batch = item[2], item[3], item[1]
                test_batch = test_batch.to(device)
                result = algorithm.predict(test_batch)
                test_feat = algorithm.return_feats(test_batch).detach().cpu()
                test_feats.append(test_feat)
                pred_y = torch.argmax(result, dim=1).cpu()
                total_correct += (pred_y == test_y_batch).sum().item()

                for (y, a) in comb_list:
                    indices = (test_y_batch == y) & (test_a_batch == a)
                    n_in_group = indices.sum().item()
                    if n_in_group > 0:
                        corrects[(y, a)] += (pred_y[indices] == test_y_batch[indices]).sum().item()
                        counts[(y, a)] += n_in_group

            test_total_acc = total_correct / len(test_dataset)
            detail_accs = {key: corrects[key] / counts[key] if counts[key] > 0 else 0.0 for key in corrects}
            print("Total:",test_total_acc, "Detail:", detail_accs)

            # Baselines
            if args.llm_augment:
                # Reload the training set so the baselines see the un-augmented data.
                train_dataset = vars(datasets)[args.dataset](
                    args.data_dir, 'tr', hparams, train_attr=args.train_attr, meta_path=meta_path)
            test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)
            train_loader = DataLoader(train_dataset, batch_size=512, shuffle=False)
            ATC_MC, ATC_NE, DOC, DOE = baseline(algorithm, train_loader, test_loader)

            if args.dataset != "MultiNLI" and args.dataset != "SNLI":
                mu = NeiborhoodInviriance(algorithm, test_dataset, device)
            else:
                mu = 0
            others.append([ATC_MC, ATC_NE, DOC, DOE, mu])

        # Ours: express the mean test feature as a mixture of the per-group
        # training means, then weight the group accuracies by that mixture.
        # This must run outside no_grad because mseTorch back-propagates.
        test_feats = torch.cat(test_feats, dim=0)
        y = torch.mean(test_feats, axis=0)
        # weight = cosineTorch(X, y, num_steps = 1000, lr = 0.01, lambdda = 0, device = device)
        weight = mseTorch(X, y, num_steps=1000, lr=0.01, lambdda=0, device=device)
        weight = weight / weight.sum()
        print(weight)
        weight_list.append(weight)
        test_accs.append(test_total_acc)
        # Calculate accuracy on training set
        for acc in train_accs:
            print("Original Subset Acc: ", acc)

        ours = (weight * train_accs).sum()
        print(ours)
        ours_list.append(ours)

    # ---- Write results ----
    if args.dataset == "MultiNLI" or args.dataset == "SNLI":
        output_name = f"{args.dataset}_bert"
    elif "resnet" in args.image_arch:
        output_name = f"{args.dataset}_resnet"
    elif "vit" in args.image_arch:
        output_name = f"{args.dataset}_vit"
    else:
        raise NotImplementedError

    if args.sample_reweight:
        output_name += "_reweight"

    if args.validate:
        output_name += "_val"

    # NOTE(review): only the first non-zero perturbation is reflected in the
    # file name; confirm that perturbations are never combined.
    if args.fog != 0:
        output_name += f"_fog{args.fog}"
    elif args.blur != 0:
        output_name += f"_blur{args.blur}"
    elif args.noise != 0:
        output_name += f"_noise{args.noise}"
    elif args.contrast != 0:
        output_name += f"_contrast{args.contrast}"
    elif args.bright != 0:
        output_name += f"_bright{args.bright}"

    # One line per test configuration: the five baseline values, our
    # estimate, the true test accuracy and the original training accuracy.
    with open(output_name + ".txt", "a") as f:
        for other, ours, test_acc in zip(others, ours_list, test_accs):
            f.write(f"{other[0]}, {other[1]}, {other[2]}, {other[3]}, {other[4]}, {ours}, {test_acc}, {ori_train_acc}\n")