import os
import numpy as np
import random
import argparse
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import ConcatDataset, random_split, DataLoader
from tqdm import tqdm
import yaml
import argparse
import wandb
import copy
import torchvision

from utils import load_divdis_model_from_path, m_f, set_seeds, replace_task_labels
from data import get_dataset, WrappedDataLoader

# Command-line interface. Flags tagged TODO are the ones expected to be
# filled in per experiment.
parser = argparse.ArgumentParser()
add_arg = parser.add_argument

add_arg("--seed", type=int, default=0)
add_arg("--task_size", type=int, default=2)  # TODO
add_arg("--batch_size_train", type=int, default=128)
add_arg("--batch_size_eval", type=int, default=512)
add_arg("--train_epochs", type=int, default=60)
add_arg("--val_ratio", type=float, default=0.1)
add_arg("--eval_every_n_epochs", type=int, default=10)
add_arg("--task_type", type=str, default="discovered")  # TODO
add_arg("--task_idx", type=int, default=0)  # TODO: only for discovered tasks
add_arg("--lr", type=float, default=0.001)
add_arg("--l2_reg", type=float, default=0.0005)
add_arg("--model", type=str, default="resnet50")
add_arg("--ckpt_method", type=str, default="D-BAT")  # TODO
add_arg("--pretrained", action="store_true")
add_arg("--ckpt_path", type=str, default="../ckpts/")
add_arg("--ckpt_name", type=str, default=None)  # TODO
add_arg("--dataset", type=str, default="camelyon17")  # or waterbird
add_arg("--data_dir", type=str, default="YOUR DIR")  # TODO
add_arg("--n_classes", type=int, default=2)
add_arg("--opt", type=str, default="sgd")
add_arg("--device", type=str, default="cuda:0")
add_arg("--perturb_type", default="ood_is_test", choices=["ood_is_test", "ood_is_not_test"])
add_arg("--notes_name", type=str, default="na")

## Logging params
add_arg("--group", type=str, default="d-bat_divdis_AS_waterbird")
add_arg("--notes", type=str, default="")
add_arg("--tags", type=str, nargs="*", default=[])
add_arg("--project_name", type=str, default="AS-Eval")
add_arg("--entity", type=str, default="task-discovery")
add_arg("--nologger", action="store_true")
add_arg("--resume_id", default="")
# Passing --no_resume clears args.resume, which otherwise defaults to True.
add_arg("--no_resume", dest="resume", action="store_false")

args = parser.parse_args()
# Seed the RNGs through the project helper before any data or model is created.
set_seeds(args.seed)

# Wandb logger
exp_name = f"{args.dataset}_{args.task_type}_{args.ckpt_method}_{args.notes_name}_seed{args.seed}_task{args.task_idx}_pretrained={args.pretrained}_{args.model}"

# Metrics container. It is a dict even with --nologger so that later code can
# assign into it without having to check for None first.
logs = {"acc_agreement": None, "final_acc_agreement": None}

if not args.nologger:
    # Only forward a run id when resuming is enabled AND a non-empty id was
    # given; the default --resume_id is the empty string.
    run_id = args.resume_id if (args.resume and args.resume_id) else None

    wandb.login()
    wandb.init(
        name=exp_name,
        project=args.project_name,
        entity=args.entity,
        tags=args.tags,
        group=args.group,
        notes=args.notes,
        id=run_id,
        # With an explicit id, "allow" continues that run if it exists instead
        # of overwriting it; without an id a fresh run is started as before.
        resume="allow" if run_id is not None else None,
        # Save the hyperparameters with the run. Assigning to `wandb.config`
        # would only rebind the module attribute and record nothing.
        config=vars(args),
    )

# Load the dataset splits and build the loaders used to infer task labels.
# Shuffling stays disabled for these loaders -- presumably so the predictions
# collected from them line up with dataset order (confirm against
# replace_task_labels).
data_train, data_valid, data_test, data_perturb = get_dataset(args)

_label_loader_kwargs = dict(
    batch_size=args.batch_size_eval,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# The wrapper callback drops the third (meta) element of every batch. Batches
# are NOT moved to a device here; the loops that consume them call .to(device).
train_dl = WrappedDataLoader(
    DataLoader(data_train, **_label_loader_kwargs),
    lambda x, y, meta: (x, y),
)
perturb_dl = WrappedDataLoader(
    DataLoader(data_perturb, **_label_loader_kwargs),
    lambda x, y, meta: (x, y),
)

# Define the task's ckpt: the network whose predictions define a "discovered" task.
if args.ckpt_name is not None:
    ckpt_file = os.path.join(args.ckpt_path, args.ckpt_name)
    if args.ckpt_method == "D-BAT":
        # The file holds a dict whose 'ensemble' entry is indexed by task_idx
        # to obtain one member's state dict. Load onto CPU first so the file
        # can be read regardless of the device it was saved from.
        ensemble_states = torch.load(ckpt_file, map_location="cpu")['ensemble']
        task_ckpt = m_f(pretrained=False)
        # load_state_dict copies values into the model, so no deepcopy is needed.
        task_ckpt.load_state_dict(ensemble_states[args.task_idx])
    elif args.ckpt_method == "DivDis":
        # task_ckpt = load_divdis_model_from_path(os.path.join(args.ckpt_path, args.ckpt_name), args.model)
        task_ckpt = torchvision.models.resnet50(pretrained=False)
        # One block of n_classes logits per head. This matches the slicing
        # applied to DivDis logits when task labels are computed, and equals
        # the former hard-coded `2 * task_size` for the default n_classes=2.
        task_ckpt.fc = nn.Linear(task_ckpt.fc.in_features, args.n_classes * args.task_size)
        # The file stores a whole pickled module; only its weights are used.
        task_ckpt.load_state_dict(torch.load(ckpt_file, map_location="cpu").state_dict())
    else:
        raise NotImplementedError(f"Unsupported --ckpt_method: {args.ckpt_method!r}")
    task_ckpt.eval()
else:
    # No checkpoint given: use a freshly constructed network.
    # NOTE(review): as before, .eval() is not called on this branch -- confirm
    # whether that is intended when it is used to produce task labels.
    task_ckpt = m_f(pretrained=False)

# Inputs are always moved to args.device before being fed to task_ckpt, so put
# every variant of the model there (previously only the DivDis branch did).
task_ckpt.to(args.device)

# Get task labels, and replace labels with task labels
def _predict_task_labels(loader):
    """Return task_ckpt's argmax prediction for every sample in ``loader``.

    Batches are moved to ``args.device`` for the forward pass. For DivDis
    checkpoints only the ``n_classes`` logits belonging to head ``task_idx``
    are considered. Predictions are gathered on the CPU, batch by batch, and
    returned as one concatenated tensor in loader order.
    """
    batch_labels = []
    with torch.no_grad():
        for x, _ in tqdm(loader):
            logits = task_ckpt(x.to(args.device))
            # Additional logic for DivDis: select this task's slice of logits.
            if args.ckpt_method == "DivDis":
                head_start = args.task_idx * args.n_classes
                logits = logits[:, head_start : head_start + args.n_classes]
            # Move each batch off the device right away instead of keeping
            # every prediction there until the final concatenation.
            batch_labels.append(logits.argmax(1).cpu())
    return torch.cat(batch_labels)


if args.task_type == "discovered":
    # Labels come from the task checkpoint's predictions on both splits.
    print("eval on train......\n")
    train_task_labels = _predict_task_labels(train_dl)
    print("eval on unlabeled......\n")
    perturb_task_labels = _predict_task_labels(perturb_dl)
elif args.task_type == "random":
    # Uniformly random binary labels for both splits.
    train_task_labels = torch.randint(0, 2, (len(data_train),))
    perturb_task_labels = torch.randint(0, 2, (len(data_perturb),))
elif args.task_type == "semantic":
    # No replacement labels are supplied for either split.
    train_task_labels, perturb_task_labels = None, None
elif args.task_type == "semantic_random":
    # No replacement for the train split, random binary labels for the perturbation split.
    train_task_labels = None
    perturb_task_labels = torch.randint(0, 2, (len(data_perturb),))
# elif args.task_type == "semantic_discovered":
#     pass
else:
    raise NotImplementedError(f"Unsupported --task_type: {args.task_type!r}")

# Hand the task labels chosen above (possibly None) to the project helper,
# which returns the train and perturbation datasets to use from here on.
data_train, data_perturb = replace_task_labels(
    data_name=args.dataset,
    data_train=data_train,
    data_perturb=data_perturb,
    train_task_labels=train_task_labels,
    perturb_task_labels=perturb_task_labels,
)

# Get loaders: pool both datasets, then hold out a val_ratio fraction
# (rounded down) for validation.
whole_dataset = ConcatDataset([data_train, data_perturb])
val_size = int(len(whole_dataset) * args.val_ratio)
split_sizes = [len(whole_dataset) - val_size, val_size]
data_train_as, data_valid_as = random_split(whole_dataset, split_sizes)

# Training batches are shuffled; validation batches are not. In both cases the
# wrapper callback drops the third (meta) element of each batch.
as_train_dl = WrappedDataLoader(
    DataLoader(
        data_train_as,
        batch_size=args.batch_size_train,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    ),
    lambda x, y, meta: (x, y),
)
as_valid_dl = WrappedDataLoader(
    DataLoader(
        data_valid_as,
        batch_size=args.batch_size_eval,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    ),
    lambda x, y, meta: (x, y),
)

# Define the two models (either pre-trained or not) whose agreement is measured.
# Batches are moved to args.device in the training and evaluation loops, so the
# models are placed there explicitly as well. What device m_f itself uses is
# not visible from this file; .to() is a no-op if they are already there.
model1 = m_f(pretrained=args.pretrained).to(args.device)
model2 = m_f(pretrained=args.pretrained).to(args.device)

# Define AS metric and the loss
criterion = nn.CrossEntropyLoss()

# Define optimizer and prepare the training. A single optimizer updates both
# models; it is created after the models are on their final device.
joint_params = list(model1.parameters()) + list(model2.parameters())
if args.opt == 'adamw':
    optimizer = torch.optim.AdamW(joint_params, lr=args.lr, weight_decay=0.05)
else:
    # Any value other than 'adamw' selects SGD with momentum.
    optimizer = torch.optim.SGD(joint_params, lr=args.lr, momentum=0.9, weight_decay=args.l2_reg)

def _agreement_score(loader):
    """Return the fraction of samples in ``loader`` on which both models agree.

    Puts model1 and model2 into eval mode, runs them without gradient tracking
    on every batch (moved to ``args.device``), and compares their argmax
    predictions. The labels yielded by the loader are ignored.
    """
    model1.eval()
    model2.eval()
    preds1, preds2 = [], []
    with torch.no_grad():
        for x, _ in tqdm(loader):
            x = x.to(args.device)
            preds1.append(model1(x).argmax(1))
            preds2.append(model2(x).argmax(1))
    return (torch.cat(preds1) == torch.cat(preds2)).float().mean().item()


# Training: both models see the same batches and are optimised jointly on the
# sum of their cross-entropy losses.
for epoch in tqdm(range(args.train_epochs)):
    # Re-enter train mode every epoch; a periodic evaluation may have switched
    # the models to eval mode.
    model1.train()
    model2.train()
    for x, y in tqdm(as_train_dl):
        optimizer.zero_grad()
        x, y = x.to(args.device), y.to(args.device)
        loss = criterion(model1(x), y) + criterion(model2(x), y)
        loss.backward()
        optimizer.step()

    # Periodic evaluation runs only when logging is enabled. The `> 0` check
    # comes first so that a zero interval never reaches the modulo.
    if args.eval_every_n_epochs > 0 and (not args.nologger) and epoch % args.eval_every_n_epochs == 0:
        # Eval AS
        as_score = _agreement_score(as_valid_dl)
        logs["acc_agreement"] = as_score
        print(f"epoch {epoch}'s acc greement:", as_score)
        wandb.log(logs)

# Final AS evaluation, always performed and printed.
as_score = _agreement_score(as_valid_dl)

print("final acc greement:", as_score)

# Touch `logs` only when logging is enabled: with --nologger it may be None,
# and assigning into it unconditionally would raise a TypeError.
if not args.nologger:
    logs["final_acc_agreement"] = as_score
    wandb.log(logs)
