import torch
from model import *
import torch.optim as optim
from train_model import *
from util import *
from da_algo import *
from ot_util import generate_domains
from dataset import *
import clip
import copy
import argparse
import random

# Compute device shared by every model in this script: the default CUDA device
# when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_source_model(trainset, testset, n_class, mode, encoder=None, epochs=50, verbose=True):
    """Train and return a classifier on labelled source-domain data.

    Args:
        trainset: dataset used for training; shuffled, batches of 128.
        testset: dataset evaluated (via ``test``) after every 10th epoch.
        n_class: number of output classes.
        mode: forwarded to ``MLP`` when an ``encoder`` is given.
        encoder: optional feature extractor. When given, the model is
            ``Classifier(encoder, MLP(...))``. When ``None``, an ``FC`` network
            is used whose input size is ``trainset[0][0].shape[0]``
            (presumably pre-encoded feature vectors — confirm with callers).
        epochs: number of training epochs (1-based epoch numbers are passed
            to ``train``).
        verbose: forwarded to ``train``/``test``; also controls the start
            message.

    Returns:
        The trained model, already moved to ``device``.
    """
    if verbose:
        print("Start training source model")

    if encoder is None:
        in_features = trainset[0][0].shape[0]
        net = FC(in_features, n_class=n_class, hdim=1024)
    else:
        net = Classifier(encoder, MLP(mode=mode, n_class=n_class, hidden=1024))
    net = net.to(device)

    # NOTE: weight_decay=1e-4 was tried previously and is currently disabled.
    optimizer = optim.Adam(net.parameters(), lr=1e-4)
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True)
    eval_loader = DataLoader(testset, batch_size=128, shuffle=False)

    for step in range(epochs):
        epoch = step + 1
        train(epoch, train_loader, net, optimizer, verbose=verbose)
        # Periodic evaluation only, to keep training fast.
        if epoch % 10 == 0:
            test(eval_loader, net, verbose=verbose)

    return net


def run_mnist_experiment(target, gt_domains, generated_domains):
    """Run the rotated-MNIST gradual self-training experiment for one target angle.

    Steps:
      1. Train an ``ENCODER``-based source model on unrotated digits.
      2. Self-train a copy of it directly on the ``target``-degree rotation.
      3. Self-train the original through ``gt_domains`` evenly spaced
         intermediate rotations followed by the target.
      4. Generate ``generated_domains`` synthetic domains between every
         consecutive pair of encoded domains, then self-train the classifier
         head (``.mlp``) on them.

    Five accuracies are appended as one line to
    ``logs/mnist_{target}_layer2.txt``. Reads the module-level ``args``
    (for the seed), which only exists when this file is run as a script.
    """
    source_data = get_single_rotate(False, 0)
    target_data = get_single_rotate(False, target)

    backbone = ENCODER().to(device)
    source_model = get_source_model(source_data, source_data, 10, "mnist", encoder=backbone, epochs=20)
    baseline = copy.deepcopy(source_model)

    # Real intermediate rotations, integer-spaced between 0 and `target`.
    n_steps = gt_domains + 1
    sequence = [get_single_rotate(False, k * target // n_steps) for k in range(1, n_steps)]
    sequence.append(target_data)

    direct_acc, st_acc = self_train(baseline, [target_data], epochs=10)
    direct_acc_all, st_acc_all = self_train(source_model, sequence, epochs=10)

    # Encode every domain with the (now gradually adapted) encoder:
    # source, target, then the intermediates, which are inserted in between.
    feats = source_model.encoder
    e_source = get_encoded_dataset(feats, source_data)
    e_target = get_encoded_dataset(feats, target_data)
    chain = [e_source, *(get_encoded_dataset(feats, ds) for ds in sequence[:-1]), e_target]

    synthetic = []
    for left, right in zip(chain, chain[1:]):
        synthetic.extend(generate_domains(generated_domains, left, right))

    _, generated_acc = self_train(source_model.mlp, synthetic, epochs=10)

    with open(f"logs/mnist_{target}_layer2.txt", "a") as f:
        scores = (direct_acc, st_acc, direct_acc_all, st_acc_all, generated_acc)
        row = ",".join(str(round(score.item(), 2)) for score in scores)
        f.write(f"seed{args.seed}with{gt_domains}gt{generated_domains}generated,{row}\n")


def run_mnist_ablation(target, gt_domains, generated_domains):
    """Compare transport plans for generating intermediate domains on rotated MNIST.

    Steps:
      1. Train a source model on unrotated digits.
      2. Self-train a copy of it directly on the ``target``-degree rotation.
      3. Self-train the model gradually through ``gt_domains`` evenly spaced
         intermediate rotations.
      4. Generate ``generated_domains`` synthetic domains between each
         consecutive pair of encoded domains, once per plan: random and
         uniform (both from ``ot_ablation``) and the identity ("ground-truth")
         plan.
      5. Self-train a separate copy of the classifier head on each sequence.

    One line of accuracies is appended to
    ``logs/mnist_{target}_{generated_domains}_ablation.txt``. Reads the
    module-level ``args`` (for the seed), which only exists when this file is
    run as a script.
    """
    encoder = ENCODER().to(device)
    src_trainset, tgt_trainset = get_single_rotate(False, 0), get_single_rotate(False, target)
    source_model = get_source_model(src_trainset, src_trainset, 10, "mnist", encoder=encoder, epochs=20)
    model_copy = copy.deepcopy(source_model)

    all_sets = []
    for i in range(1, gt_domains+1):
        all_sets.append(get_single_rotate(False, i*target//(gt_domains+1)))
        print(i*target//(gt_domains+1))
    all_sets.append(tgt_trainset)

    direct_acc, st_acc = self_train(model_copy, [tgt_trainset], epochs=10)
    direct_acc_all, st_acc_all = self_train(source_model, all_sets, epochs=10)
    # Each plan variant adapts its own copy of the gradually self-trained model.
    model_copy1 = copy.deepcopy(source_model)  # random plan
    model_copy3 = copy.deepcopy(source_model)  # ground-truth (identity) plan
    model_copy4 = copy.deepcopy(source_model)  # uniform plan

    e_src_trainset, e_tgt_trainset = get_encoded_dataset(source_model.encoder, src_trainset), get_encoded_dataset(source_model.encoder, tgt_trainset)
    intersets = all_sets[:-1]
    encoded_intersets = [e_src_trainset]
    for i in intersets:
        encoded_intersets.append(get_encoded_dataset(source_model.encoder, i))
    encoded_intersets.append(e_tgt_trainset)

    # random plan
    all_domains1 = []
    for i in range(len(encoded_intersets)-1):
        plan = ot_ablation(len(src_trainset), "random")
        all_domains1 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1], plan=plan)
    _, generated_acc1 = self_train(model_copy1.mlp, all_domains1, epochs=10)

    # uniform plan
    all_domains4 = []
    for i in range(len(encoded_intersets)-1):
        plan = ot_ablation(len(src_trainset), "uniform")
        all_domains4 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1], plan=plan)
    _, generated_acc4 = self_train(model_copy4.mlp, all_domains4, epochs=10)

    # The OT-plan variant (generate_domains' default, no `plan=`) is disabled
    # here; run_mnist_experiment uses that default.

    # ground-truth plan: sample k of one domain is paired with sample k of the
    # next. NOTE(review): assumes every rotated set lists the same images in the
    # same order, and that generate_domains accepts an unnormalised plan — confirm.
    # Built once (it is loop-invariant and n x n in size) and now actually passed
    # to generate_domains; previously it was ignored, so this column silently
    # reported the default plan.
    gt_plan = np.identity(len(src_trainset))
    all_domains3 = []
    for i in range(len(encoded_intersets)-1):
        all_domains3 += generate_domains(generated_domains, encoded_intersets[i], encoded_intersets[i+1], plan=gt_plan)
    _, generated_acc3 = self_train(model_copy3.mlp, all_domains3, epochs=10)

    # NOTE(review): unlike run_mnist_experiment, direct_acc_all is not logged;
    # the column layout is kept unchanged so existing log parsers still work.
    with open(f"logs/mnist_{target}_{generated_domains}_ablation.txt", "a") as f:
        f.write(f"seed{args.seed}generated{generated_domains},{round(direct_acc.item(), 2)},{round(st_acc.item(), 2)},{round(st_acc_all.item(), 2)},{round(generated_acc1.item(), 2)},{round(generated_acc4.item(), 2)},{round(generated_acc3.item(), 2)}\n")


def run_portraits_experiment(gt_domains, generated_domains, pretrain="none"):
    """Run the Portraits gradual self-training experiment and append one result line.

    Steps:
      1. Load the Portraits splits and train a 2-class source model.
      2. Self-train it directly on the target.
      3. Self-train it gradually through ``gt_domains`` real intermediate
         domains.
      4. Generate ``generated_domains`` synthetic domains between every
         consecutive pair of encoded domains and self-train the classifier
         head on them.

    Accuracies are appended to ``logs/portraits_exp.txt``.

    Args:
        gt_domains: number of real intermediate domains; must be 0-4.
        generated_domains: synthetic domains generated between each domain pair.
        pretrain: ``"clip"`` uses the module-level CLIP globals ``model`` /
            ``preprocess`` (created in the ``__main__`` block) as a fixed
            feature extractor, and trains an ``FC`` head on its features.
            Any other value trains ``ENCODER`` + ``MLP`` end-to-end.

    Raises:
        ValueError: if ``gt_domains`` is not one of 0-4.

    Reads the module-level ``args`` (for the seed).
    """
    # Which 2000-sample chunks of the intermediate data form the real
    # intermediate domains, for each supported number of domains.
    n2idx = {0: [], 1: [3], 2: [2, 4], 3: [1, 3, 5], 4: [0, 2, 4, 6]}
    if gt_domains not in n2idx:
        # Fail before loading data and training rather than with a KeyError later.
        raise ValueError(f"gt_domains must be one of {sorted(n2idx)}, got {gt_domains!r}")

    (src_tr_x, src_tr_y, src_val_x, src_val_y, inter_x, inter_y, dir_inter_x, dir_inter_y,
        trg_val_x, trg_val_y, trg_test_x, trg_test_y) = make_portraits_data(1000, 1000, 14000, 2000, 1000, 1000)
    tr_x, tr_y = np.concatenate([src_tr_x, src_val_x]), np.concatenate([src_tr_y, src_val_y])
    ts_x, ts_y = np.concatenate([trg_val_x, trg_test_x]), np.concatenate([trg_val_y, trg_test_y])

    if pretrain == "clip":
        encoder = model.encode_image
        transforms = preprocess
    else:
        encoder = ENCODER().to(device)
        transforms = ToTensor()

    src_trainset = EncodeDataset(tr_x, tr_y.astype(int), transforms)
    tgt_trainset = EncodeDataset(ts_x, ts_y.astype(int), transforms)
    if pretrain == "clip":
        # Encode the source once with CLIP; the classifier is trained on these features.
        e_src_trainset = get_encoded_dataset(encoder, src_trainset)
        source_model = get_source_model(e_src_trainset, e_src_trainset, 2, mode="portraits", epochs=20)
    else:
        source_model = get_source_model(src_trainset, src_trainset, 2, mode="portraits", encoder=encoder, epochs=20)
    model_copy = copy.deepcopy(source_model)

    domain_idx = n2idx[gt_domains]
    print(domain_idx)
    # Intermediate domains must use the same preprocessing as source/target
    # (previously hard-coded ToTensor(), which is wrong for the CLIP path).
    all_sets = [
        EncodeDataset(inter_x[i*2000:(i+1)*2000], inter_y[i*2000:(i+1)*2000].astype(int), transforms)
        for i in domain_idx
    ]
    all_sets.append(tgt_trainset)

    if pretrain == "clip":
        # The source model is a head over CLIP features, so CLIP stays the
        # feature extractor and the whole source model is the head.
        feature_encoder, head = encoder, source_model
    else:
        direct_acc, st_acc = self_train(model_copy, [tgt_trainset], epochs=10)
        direct_acc_all, st_acc_all = self_train(source_model, all_sets, epochs=10)
        # Freeze the adapted encoder; later self-training only updates the MLP head.
        for param in source_model.encoder.parameters():
            param.requires_grad = False
        feature_encoder, head = source_model.encoder, source_model.mlp
        e_src_trainset = get_encoded_dataset(feature_encoder, src_trainset)

    e_tgt_trainset = get_encoded_dataset(feature_encoder, tgt_trainset)
    encoded_intersets = [e_src_trainset]
    encoded_intersets += [get_encoded_dataset(feature_encoder, s) for s in all_sets[:-1]]
    encoded_intersets.append(e_tgt_trainset)

    if pretrain == "clip":
        direct_acc, st_acc = self_train(model_copy, [e_tgt_trainset], epochs=10)
        direct_acc_all, st_acc_all = self_train(source_model, encoded_intersets[1:], epochs=10)

    all_domains = []
    for prev_set, next_set in zip(encoded_intersets, encoded_intersets[1:]):
        all_domains += generate_domains(generated_domains, prev_set, next_set)

    _, generated_acc = self_train(head, all_domains, epochs=10)

    with open("logs/portraits_exp.txt", "a") as f:
        f.write(f"seed{args.seed}with{gt_domains}gt{generated_domains}generated,{round(direct_acc.item(), 2)},{round(st_acc.item(), 2)},{round(direct_acc_all.item(), 2)},{round(st_acc_all.item(), 2)},{round(generated_acc.item(), 2)}\n")


def main(args):
    """Seed RNGs (when ``args.seed`` is set) and run the experiment chosen by ``args.mode``.

    ``mnist`` dispatches on ``args.mnist_mode``:
      - ``"normal"`` runs ``run_mnist_experiment``.
      - Anything else runs ``run_mnist_ablation``.

    Every other mode calls the module-level function
    ``run_<mode>_experiment(gt_domains, generated_domains)``.

    Raises:
        NotImplementedError: if no ``run_<mode>_experiment`` exists in this
            module's namespace. The CLI accepts ``cifar``, ``office31`` and
            ``office_home``, but no runner for them is visible here.
    """
    print(args)

    if args.seed is not None:
        torch.cuda.manual_seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    if args.mode == "mnist":
        if args.mnist_mode == "normal":
            run_mnist_experiment(args.rotation_angle, args.gt_domains, args.generated_domains)
        else:
            run_mnist_ablation(args.rotation_angle, args.gt_domains, args.generated_domains)
        return

    # Look the runner up by name instead of eval()-ing a formatted string.
    runner = globals().get(f"run_{args.mode}_experiment")
    if runner is None:
        raise NotImplementedError(f"no experiment runner defined for mode {args.mode!r}")
    # NOTE(review): args.pretrain is not forwarded, so run_portraits_experiment
    # always uses its default pretrain="none" even though --pretrain defaults to "clip".
    runner(args.gt_domains, args.generated_domains)

if __name__ == '__main__':
    # Command-line interface. `args` must remain a module-level global: the
    # run_* functions read `args.seed` from it directly.
    parser = argparse.ArgumentParser(description="GOAT experiments")
    # (positional/flag names, keyword options) in the order they are registered.
    _cli_options = (
        (("mode",), dict(choices=["mnist", "cifar", "portraits", "office31", "office_home"])),
        (("--pretrain",), dict(default="clip", choices=["imagenet", "clip", "none"])),
        (("--model",), dict(default="RN50", choices=["RN50", "RN101", "ViT-B/32", "ViT-B/16"])),
        (("--gt-domains",), dict(default=0, type=int)),
        (("--generated-domains",), dict(default=0, type=int)),
        (("--seed",), dict(default=None, type=int)),
        (("--mnist-mode",), dict(default="normal", choices=["normal", "ablation"])),
        (("--rotation-angle",), dict(default=45, type=int)),
        # NOTE(review): --batch-size is parsed but not read by the visible code.
        (("--batch-size",), dict(default=128, type=int)),
    )
    for _flags, _opts in _cli_options:
        parser.add_argument(*_flags, **_opts)
    args = parser.parse_args()

    if args.pretrain == "clip":
        # Module-level globals: run_portraits_experiment reads `model` and
        # `preprocess` when it runs with pretrain="clip".
        model, preprocess = clip.load(args.model, device)
        # Not referenced by the visible code; kept as a global for compatibility.
        encoder = model.encode_image

    main(args)