import os
import os.path as osp
import argparse
import random
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from rich import pretty


from domainnet_loader import DomainNetDataset  
from models.bag import BAG
from trainers.bag_trainer import BAGTrainer
from misc import initialize_torchvision_model  
pretty.install()


class ResNet(nn.Module):
    """Wrap a backbone so that every ``nn.BatchNorm2d`` stays in eval mode.

    Frozen BatchNorm layers normalize with their stored running statistics and
    do not update them, even while the rest of the network is in training
    mode. Only the statistics are frozen: the BN affine parameters keep
    ``requires_grad`` untouched and are still optimized.

    Args:
        model: backbone module; ``forward`` returns its output unchanged.
    """

    def __init__(self, model):
        super().__init__()
        self.network = model
        self.freeze_bn()

    def forward(self, x):
        return self.network(x)

    def train(self, mode=True):
        """Switch train/eval mode like ``nn.Module.train``, then re-freeze BN.

        Returns ``self`` (as ``nn.Module.train`` does). ``nn.Module.eval``
        delegates to this method and returns its result, so without the
        return both ``model.train()`` and ``model.eval()`` would yield ``None``.
        """
        super().train(mode)
        self.freeze_bn()
        return self

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` inside the backbone into eval mode."""
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

def set_seed(seed: int):
    """Seed every RNG the script relies on and force deterministic cuDNN.

    Covers Python's ``random``, NumPy's global generator, PyTorch on CPU and
    on all CUDA devices. cuDNN autotuning is disabled and deterministic
    kernels are requested, trading some speed for run-to-run reproducibility.
    """
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


from torch.utils.data import Dataset
class DomainSubsetByEnv(Dataset):
    """View of ``base_ds`` restricted to the samples of a single domain.

    ``base_ds`` must expose ``domain_labels`` (iterable, position-aligned with
    its indices) and return ``(img, y, domain)`` triples from indexing. This
    view yields ``(img, y)`` pairs and drops the domain label.

    Args:
        base_ds: the full dataset to filter.
        domain_id: the domain label whose samples are kept.
    """

    def __init__(self, base_ds: DomainNetDataset, domain_id: int):
        self.base = base_ds
        # Positions in base_ds whose domain label matches, in original order.
        self.idx = [
            position
            for position, label in enumerate(base_ds.domain_labels)
            if label == domain_id
        ]

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        image, target, _domain = self.base[self.idx[i]]
        return image, target


def build_transforms(img_size=224, resize_size=None):
    """Build the train and test preprocessing pipelines.

    Both pipelines resize the shorter image side to ``resize_size``, crop an
    ``img_size`` square (random crop for training, center crop for testing),
    convert to a tensor and normalize with the ImageNet mean/std that
    torchvision's pretrained models expect.

    Args:
        img_size: side length of the square crop fed to the network.
        resize_size: target length of the shorter side before cropping.
            Defaults to ``max(256, img_size)``. This equals the previous
            fixed 256 for ``img_size <= 256`` and avoids ``RandomCrop``
            failing when the crop would be larger than the resized image.

    Returns:
        ``(train_tf, test_tf)`` torchvision ``Compose`` pipelines.

    Raises:
        ValueError: an explicit ``resize_size`` is smaller than ``img_size``.
    """
    if resize_size is None:
        resize_size = max(256, img_size)
    elif resize_size < img_size:
        raise ValueError(
            f"resize_size ({resize_size}) must be >= img_size ({img_size})"
        )

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.RandomCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])
    return train_tf, test_tf


# Integer id of each DomainNet domain (alphabetical order). The id is written
# as the third column of the generated image-list files and selects the
# per-domain subsets used for training and testing.
DOM2ID = {
    domain_name: domain_id
    for domain_id, domain_name in enumerate(
        ("clipart", "infograph", "painting", "quickdraw", "real", "sketch")
    )
}

def _read_and_tag(lines, domain_id):
    tagged = []
    for ln in lines:
        if not ln.strip():
            continue
        parts = ln.strip().split()
        if len(parts) == 2:

            tagged.append(f"{parts[0]} {int(parts[1])} {int(domain_id)}\n")
        elif len(parts) >= 3:

            tagged.append(f"{parts[0]} {int(parts[1])} {int(domain_id)}\n")
        else:
            raise ValueError(f"Bad line: {ln}")
    return tagged

def _write_lines_atomic(path, lines):
    """Write ``lines`` to ``path`` via a temporary file and an atomic rename.

    A crash or interrupt mid-write never leaves a truncated file at ``path``.
    This matters because ``ensure_image_list_files`` reuses any existing
    output file without re-checking its contents.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            f.writelines(lines)
        os.replace(tmp_path, path)
    finally:
        # The temp file only remains if writing or renaming failed.
        if osp.exists(tmp_path):
            os.remove(tmp_path)


def ensure_image_list_files(
    data_root: str,
    train_domains: list,
    target_domain: str,
    list_dir_name: str = "image_list"
):
    """Create (if missing) the merged train list and the target test list.

    Each source list ``<data_root>/<domain>_train.txt`` (or ``_test.txt`` for
    the target) is re-tagged with the domain's id from ``DOM2ID`` via
    ``_read_and_tag``. The result is written into ``<data_root>/<list_dir_name>/``.
    Output files that already exist are reused as-is; delete them to force a
    rebuild. Domain names are matched case-insensitively and appear
    lower-cased in the generated file names.

    Args:
        data_root: directory holding the per-domain ``*_train.txt`` and
            ``*_test.txt`` lists.
        train_domains: source domain names. Their order is kept both in the
            merged list and in its file name.
        target_domain: held-out domain whose test list is prepared.
        list_dir_name: sub-directory of ``data_root`` for the generated lists.

    Returns:
        ``(train_list_name, test_list_name)``: bare file names inside the
        list directory.

    Raises:
        ValueError: ``train_domains`` is empty or a name is not in ``DOM2ID``.
        FileNotFoundError: a required source list is missing.
    """
    # Validate everything before touching the filesystem. Explicit raises are
    # used because assert statements are stripped under ``python -O``.
    train_keys = [d.lower() for d in train_domains]
    if not train_keys:
        raise ValueError("train_domains must contain at least one domain")
    for d in train_keys:
        if d not in DOM2ID:
            raise ValueError(f"Unknown domain: {d}")
    tgt = target_domain.lower()
    if tgt not in DOM2ID:
        raise ValueError(f"Unknown target domain: {tgt}")

    list_dir = osp.join(data_root, list_dir_name)
    os.makedirs(list_dir, exist_ok=True)

    train_out_name = f"DomainNet_train_{'-'.join(train_keys)}.txt"
    train_out_path = osp.join(list_dir, train_out_name)
    if not osp.exists(train_out_path):
        merged = []
        for d in train_keys:
            src_file = osp.join(data_root, f"{d}_train.txt")
            if not osp.exists(src_file):
                raise FileNotFoundError(f"Missing train list: {src_file}")
            with open(src_file, "r") as f:
                merged.extend(_read_and_tag(f.readlines(), DOM2ID[d]))
        _write_lines_atomic(train_out_path, merged)

    test_out_name = f"DomainNet_test_{tgt}.txt"
    test_out_path = osp.join(list_dir, test_out_name)
    if not osp.exists(test_out_path):
        src_file = osp.join(data_root, f"{tgt}_test.txt")
        if not osp.exists(src_file):
            raise FileNotFoundError(f"Missing test list: {src_file}")
        with open(src_file, "r") as f:
            tagged = _read_and_tag(f.readlines(), DOM2ID[tgt])
        _write_lines_atomic(test_out_path, tagged)

    return train_out_name, test_out_name

def main(args):
    """Build DomainNet data, the ResNet-50 backbone and the BAG model, then train.

    Args:
        args: parsed ``argparse.Namespace``. Besides reading CLI options, this
            function writes derived fields onto it (``device``,
            ``model_save_dir``, ``num_class``, ``classification``, ``n_envs``,
            ``environment_num``, ``torch_loader``, ``num_workers``,
            ``model_kwargs``, ``phi_odim``). ``args`` is then passed as
            ``config`` to ``BAG`` and ``BAGTrainer``, which presumably read
            those fields (TODO confirm).

    Returns:
        The value returned by ``BAGTrainer.train``. It is printed under the
        label ``best_metric(test_tta_result_acc)``; its exact meaning is
        defined in BAGTrainer (TODO confirm).
    """
    # Seed first. The statements below (dataset construction, backbone and
    # BAG initialization) may consume RNG state, so their order matters for
    # reproducibility.
    set_seed(args.seed)
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    pretty.pprint(f"Using {args.device} device")

    # args.log is created if missing and doubles as model_save_dir, which is
    # later passed to the trainer as log_dir.
    os.makedirs(args.log, exist_ok=True)
    args.model_save_dir = args.log
    pretty.pprint(f"Model save directory: {args.model_save_dir}")

    # Comma-separated --train_domains -> list of names; empty entries dropped.
    train_domains = [d.strip() for d in args.train_domains.split(",") if d.strip()]
    # Generate (or reuse) the tagged lists under <data_root>/image_list and
    # get back their bare file names.
    train_list_name, test_list_name = ensure_image_list_files(
        data_root=args.data_root,
        train_domains=train_domains,
        target_domain=args.target_domain,
        list_dir_name="image_list"
    )

    # Random-crop pipeline for training, center-crop pipeline for testing.
    train_tf, test_tf = build_transforms(img_size=args.img_size)

    # NOTE(review): only the bare list file name is passed; DomainNetDataset
    # presumably resolves it under <root>/image_list — confirm.
    train_base = DomainNetDataset(
        root=args.data_root,
        image_list=train_list_name,
        transform=train_tf
    )
    test_base = DomainNetDataset(
        root=args.data_root,
        image_list=test_list_name,
        transform=test_tf
    )


    out_dim = train_base.num_classes
    args.num_class = out_dim
    args.classification = True

    # Integer ids (DOM2ID) of the source domains and of the held-out target.
    train_env_ids = [DOM2ID[d.lower()] for d in train_domains]
    target_id = DOM2ID[args.target_domain.lower()]
    pretty.pprint({"train_envs(ids)": train_env_ids, "target(id)": target_id})

    # One dataset per source domain (one "environment" each); the test set is
    # filtered to the target domain id.
    train_dataset_list = [DomainSubsetByEnv(train_base, d) for d in train_env_ids]
    test_dataset = DomainSubsetByEnv(test_base, target_id)

    # Derived config fields, set on args before it is handed to BAG / BAGTrainer.
    args.n_envs = len(train_dataset_list)
    args.environment_num = args.n_envs
    args.torch_loader = True
    args.num_workers = args.nb_workers
    print(f"train_dataset_list: {len(train_dataset_list)},total batch size: {args.batch_size * args.n_envs}.total sample: {len(train_base)}")
    print(f"test_dataset: {len(test_dataset)},total batch size: {args.batch_size}.total sample: {len(test_base)}")
    # pretrained=True is forwarded to initialize_torchvision_model; d_out is
    # requested as args.resnet_dim and the actual value is read back below.
    args.model_kwargs = {"pretrained": True}
    backbone = initialize_torchvision_model(
        name="resnet50",
        d_out=args.resnet_dim,
        **args.model_kwargs
    )
    args.phi_odim = backbone.d_out
    # Wrapper keeps the backbone's BatchNorm2d layers in eval mode.
    Phi = ResNet(backbone)

    # NOTE(review): input_dim=0 — presumably unused when a Phi feature
    # extractor is supplied; confirm in BAG.
    model = BAG(
        n_batch_envs=args.n_envs,
        input_dim=0,
        Phi=Phi,
        config=args,
        out_dim=out_dim,
        phi_dim=args.phi_odim
    ).to(args.device)


    criterion = nn.CrossEntropyLoss()
    # NOTE(review): causal_dir semantics are defined in BAGTrainer.
    trainer = BAGTrainer(
        model=model,
        loss_fn=criterion,
        reg_lambda=args.reg_lambda,
        config=args,
        causal_dir=False
    )

    # The per-domain list is passed as train_dataset; args.batch_size is the
    # per-domain batch size (see the "total batch size" print above).
    best_metric = trainer.train(
        train_dataset=train_dataset_list,
        batch_size=args.batch_size,
        test_dataset=test_dataset,
        log_dir=args.model_save_dir
    )
    pretty.pprint({"best_metric(test_tta_result_acc)": best_metric})
    return best_metric

def _build_arg_parser():
    """Return the command-line parser for DomainNet + BAG (ResNet-50) training."""
    # ``description`` must be passed by keyword: ArgumentParser's first
    # positional parameter is ``prog``, which would replace the program name
    # shown in the usage line.
    parser = argparse.ArgumentParser(description="DomainNet + BAG (ResNet-50) main")

    # Data.
    parser.add_argument("--data_root", type=str, default='./dataset/DomainNet',
                        help="DomainNet root directory; must contain "
                             "<domain>_train.txt and <domain>_test.txt lists")
    parser.add_argument("--train_domains", type=str, default="infograph,painting,quickdraw,sketch,real",
                        help="comma-separated source domains (case-insensitive), "
                             f"chosen from: {', '.join(DOM2ID)}")
    parser.add_argument("--target_domain", type=str, default="clipart",
                        choices=list(DOM2ID.keys()),
                        help="held-out target domain used for testing")

    # Model.
    parser.add_argument("--img_size", type=int, default=224,
                        help="side length of the square input crop")
    parser.add_argument("--resnet_dim", type=int, default=2048,
                        help="output dimension (d_out) requested from the ResNet-50 backbone")
    parser.add_argument("--z_dim", type=int, default=512)
    parser.add_argument("--z_c_dim", type=int, default=256)
    parser.add_argument("--z_s_dim", type=int, default=256)
    parser.add_argument("--environment_dim", type=int, default=32)
    parser.add_argument("--hide_dim", type=int, default=1024)
    parser.add_argument('--balanced_dataset', action='store_true',
                        help='use the balanced dataset (default: imbalanced)')

    # Optimization.
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=128,
                        help="batch size per training domain; the total batch "
                             "size is batch_size * number of train domains")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--fine_tune_lr", type=float, default=1e-3)
    parser.add_argument("--train_z_s_epoch", type=int, default=3)
    parser.add_argument("--nb_workers", type=int, default=8)
    parser.add_argument("--seed", type=int, default=1,
                        help="random seed passed to set_seed")

    # Loss weights / regularization.
    parser.add_argument("--reg_lambda", type=float, default=0)
    parser.add_argument("--reg_lambda_2", type=float, default=0.0)
    parser.add_argument("--gamma", type=float, default=0.0)
    parser.add_argument("--vae_lambda", type=float, default=0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--C_max", type=float, default=5.0)
    parser.add_argument("--C_stop_iter", type=int, default=10)

    # Output.
    parser.add_argument("--log", type=str, default="./log_domainnet",
                        help="directory for logs and saved models (created if missing)")
    parser.add_argument("--save_test_phi", action="store_true")
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    main(args)
