import os
from datasets.celebA import CelebA
from datasets.waterbirds import Waterbirds
from datasets.ImageNetA import ImageNetA
from datasets.imagenet_val_subset import ImageNetValSubset
from datasets.BAR import BAR
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import swin_v2_b, Swin_V2_B_Weights
import torch
from torch.optim import SGD
from torch.utils import data
from torch.nn.functional import one_hot
import os
import torchvision
from torch.utils import data
import random
import numpy as np

def set_seed(seed):
    """Seed the Python, PyTorch and NumPy RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed unchanged to ``random.seed``,
            ``torch.cuda.manual_seed`` (only when CUDA is available),
            ``torch.manual_seed`` and ``np.random.seed``.
    """
    random.seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed)

    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

    # Turn off the cuDNN autotuner and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


class Hook:
    """Registers a hook at a specific layer of a network.

    ``module`` must be indexable with at least two entries: the hook is
    registered on ``module[1]`` and ``module[0]`` is stored as ``name``.
    Elsewhere in this file it receives ``Sequential(layer, Identity())``, so
    the hook sits on the ``Identity`` and ``name`` is the wrapped layer.

    Attributes:
        name: ``module[0]``.
        hook: Handle returned by the registration call; used by ``close``.
        input: Second argument of the most recent hook call, or ``None`` if
            the hook has not fired yet.
        output: Third argument of the most recent hook call, or ``None`` if
            the hook has not fired yet.
    """

    def __init__(self, module, backward=False):
        """Attach the hook to ``module[1]``.

        Args:
            module: Indexable pair ``(name, layer)``, see the class docstring.
            backward: When truthy register a backward hook, otherwise a
                forward hook.
        """
        self.name = module[0]
        # Defined up front so that reading them before the first pass gives
        # None instead of raising AttributeError.
        self.input = None
        self.output = None

        target = module[1]
        if backward:
            # NOTE(review): PyTorch deprecates register_backward_hook in favor
            # of register_full_backward_hook. Kept as-is because switching may
            # change the gradients that get captured -- confirm before
            # migrating.
            self.hook = target.register_backward_hook(self.hook_fn)
        else:
            self.hook = target.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: remember what the layer received and produced."""
        self.input = input
        self.output = output

    def close(self):
        """Detach the hook; the last captured values stay available."""
        self.hook.remove()


class VanillaModels:

    PATH_TO_MODELS = os.path.join(os.path.curdir, "data", "saved_models")
    
    def WaterbirdsModel(setup_test: bool = False, model_name="resnet_50"):
        device = "cuda:0"
        batch_size = 128
        num_classes = 2
        num_bias_attributes = 1
        num_biases = 2
        rho = 0.95
        seed=0
        dataset_args = {"name":"waterbirds", "args": {"rho": rho}}
        model_args = {
            "model"         : model_name, 
            "opt"           : "torch.optim.SGD",
            "lr"            : 0.001,
            "momentum"      : 0.9,
            "weight_decay"  : 0,
            "bsize"         : batch_size,
            "seed"          : seed,
            "ablation"      : False
        }
        
        model_name = f"{dataset_args['name']}"
        
        for i, (key, value) in enumerate(dataset_args["args"].items()):
            model_name = f"{model_name}-{key}_{value}"
        
        for i, (key, value) in enumerate(model_args.items()):
            model_name = f"{model_name}-{key}_{value}"
        print(model_name)

        dataset = Waterbirds(env="train", return_index=True)
        if setup_test:
            test_set = Waterbirds(env="test", return_index=True)
            test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

        match model_args["model"]:
            case "resnet50":
                nb_features = 2048
                model = resnet50(weights=ResNet50_Weights.DEFAULT)
                model.avgpool = torch.nn.Sequential(model.avgpool, torch.nn.Identity())
                bottleneck = Hook(model.avgpool, backward=False)
                model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
                model = model.to(device)                
            case "vitb16":
                nb_features = 768
                model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
                model.encoder.ln = torch.nn.Sequential(model.encoder.ln, torch.nn.Identity())
                bottleneck = Hook(model.encoder.ln, backward=False)
                model.heads = torch.nn.Linear(nb_features, num_classes)
                model = model.to(device)
                

        criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
        optimizer = SGD(model.parameters(), lr=model_args["lr"], momentum=model_args["momentum"], weight_decay=model_args["weight_decay"])

        train_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

        build_dict = {
            "model"         : model, 
            "bottleneck"    : bottleneck,
            "nb_features"   : nb_features,
            "criterion"     : criterion,
            "optimizer"     : optimizer,
            "dataset"       : dataset,
            "test_set"      : test_set if setup_test else "none",
            "train_loader"  : train_loader,
            "test_loader"   : test_loader if setup_test else "none",
            "device"        : device,
            "warmup"        : 50,
            "check_dist"    : 10,
            "train_epochs"  : 10 if model_name == "resnet50" else 5
        }
        
        config = {
            "batch_size"            : batch_size,
            "num_classes"           : num_classes,
            "num_bias_attributes"   : num_bias_attributes,
            "num_biases"            : num_biases,
            "rho"                   : 0.95,
            "seed"                  : seed,
            "dataset_args"          : dataset_args,
            "model_args"            : model_args,
            "model_name"            : model_name,
            "dataset_constructor"   : str(type(dataset)),
            "model_name"            : model_name,
            "build_dict"            : build_dict
        }

        return (
            build_dict,   
            config
        )

    def CelebAModel(setup_test: bool = False):
        device = "cuda:0"
        batch_size = 256
        num_classes = 2
        num_bias_attributes = 1
        num_biases = 2
        nb_features = 512
        rho = 0.95
        seed = 0
        dataset_args = {"name":"celeba", "args": {"rho": rho}}
        model_args = {
            "model"         : "resnet18",
            "opt"           : "torch.optim.Adam",
            "lr"            : 1e-4,
            "momentum"      : 0.9,
            "weight_decay"  : 1e-4,
            "bsize"         : batch_size,
            "seed"          : seed
        }
        
        model_name = f"{dataset_args['name']}"
        
        for i, (key, value) in enumerate(dataset_args["args"].items()):
            model_name = f"{model_name}-{key}_{value}"
        
        for i, (key, value) in enumerate(model_args.items()):
            model_name = f"{model_name}-{key}_{value}"
        print(model_name)

        dataset = CelebA(root="./data", split="train")
        if setup_test:
            test_set = CelebA(root="./data", split="test")
            test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

        model = resnet18(weights=ResNet18_Weights.DEFAULT)
        model.avgpool = torch.nn.Sequential(model.avgpool, torch.nn.Identity())
        bottleneck = Hook(model.avgpool, backward=False)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        model = model.to(device)
        criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=model_args["lr"], weight_decay=model_args["weight_decay"])

        train_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

        build_dict = {
            "model"         : model, 
            "bottleneck"    : bottleneck,
            "nb_features"   : nb_features,
            "criterion"     : criterion,
            "optimizer"     : optimizer,
            "dataset"       : dataset,
            "test_set"      : test_set if setup_test else "none",
            "train_loader"  : train_loader,
            "test_loader"   : test_loader if setup_test else "none",
            "device"        : device,
            "warmup"        : 500,
            "check_dist"    : 25,
            "train_epochs"  : 150
        }
        
        config = {
            "batch_size"            : batch_size,
            "num_classes"           : num_classes,
            "num_bias_attributes"   : num_bias_attributes,
            "num_biases"            : num_biases,
            "rho"                   : "none",
            "seed"                  : seed,
            "dataset_args"          : dataset_args,
            "model_args"            : model_args,
            "model_name"            : model_name,
            "dataset_constructor"   : str(type(dataset)),
            "model_name"            : model_name,
            "build_dict"            : build_dict
        }

        return (
            build_dict,   
            config
        )

    def BARModel(setup_test: bool = False):
        device = "cuda:0"
        batch_size = 256
        num_classes = 6
        num_bias_attributes = 1
        num_biases = 6
        nb_features = 512
        rho = 0.95
        seed = 0
        dataset_args = {"name":"bar", "args": {"rho": rho}}
        model_args = {
            "model"         : "resnet18",
            "nb_features"   : nb_features,
            "opt"           : "torch.optim.SGD",
            "lr"            : 0.001,
            "momentum"      : 0.9,
            "weight_decay"  : 1e-4,
            "bsize"         : batch_size,
            "seed"          : seed
        }
        
        model_name = f"{dataset_args['name']}"
        
        for i, (key, value) in enumerate(dataset_args["args"].items()):
            model_name = f"{model_name}-{key}_{value}"
        
        for i, (key, value) in enumerate(model_args.items()):
            model_name = f"{model_name}-{key}_{value}"
        print(model_name)

        dataset = BAR(env="train", return_index=True)
        if setup_test:
            test_set = BAR(env="test", return_index=True)
            test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

        model = resnet18(weights=ResNet18_Weights.DEFAULT)
        model.avgpool = torch.nn.Sequential(model.avgpool, torch.nn.Identity())
        bottleneck = Hook(model.avgpool, backward=False)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        model = model.to(device)
        criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=model_args["lr"], weight_decay=model_args["weight_decay"])
        train_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

        build_dict = {
            "model"         : model, 
            "bottleneck"    : bottleneck,
            "nb_features"   : nb_features,
            "criterion"     : criterion,
            "optimizer"     : optimizer,
            "dataset"       : dataset,
            "test_set"      : test_set if setup_test else "none",
            "train_loader"  : train_loader,
            "test_loader"   : test_loader if setup_test else "none",
            "device"        : device,
            "warmup"        : 0,
            "check_dist"    : 5,
            "train_epochs"  : 100
        }
        
        config = {
            "batch_size"            : batch_size,
            "num_classes"           : num_classes,
            "num_bias_attributes"   : num_bias_attributes,
            "num_biases"            : num_biases,
            "rho"                   : 0.95,
            "seed"                  : seed,
            "dataset_args"          : dataset_args,
            "model_args"            : model_args,
            "model_name"            : model_name,
            "dataset_constructor"   : str(type(dataset)),
            "model_name"            : model_name,
            "build_dict"            : build_dict
        }

        return (
            build_dict,   
            config
        )
        
    def ImageNetAModel(setup_test: bool = False, model_name="resnet50"):
        device = "cuda:0"
        batch_size = 128
        num_classes = 1000
        seed = 0
        dataset_args = {"name":"ImageNet-A", "args": {"rho": "none"}}
        model_args = {
            "model"         : model_name,
            "opt"           : "torch.optim.Adam",
            "lr"            : 0.0001,
            "momentum"      : 0.9,
            "weight_decay"  : 1e-4,
            "bsize"         : batch_size,
            "seed"          : seed
        }
        
        model_name = f"{dataset_args['name']}"
        
        for i, (key, value) in enumerate(dataset_args["args"].items()):
            model_name = f"{model_name}-{key}_{value}"
        
        for i, (key, value) in enumerate(model_args.items()):
            model_name = f"{model_name}-{key}_{value}"
        print(model_name)

        
        adversarial = ImageNetA()
        unbiased_dataset = ImageNetValSubset(
            root="./data", 
            classes_subset=adversarial.dataset.classes, 
            external_cls_to_idx=None,
            pick_ratio=1.0
        )
        adversarial.external_cls_to_idx = unbiased_dataset.cls_to_idx
        
        num_bias_attributes = 1
        num_biases = num_classes

        dataset = data.ConcatDataset((unbiased_dataset, adversarial))

        match model_args["model"]:
            case "resnet50":
                nb_features = 2048
                model = resnet50(weights=ResNet50_Weights.DEFAULT)
                model.avgpool = torch.nn.Sequential(model.avgpool, torch.nn.Identity())
                bottleneck = Hook(model.avgpool, backward=False)
                model = model.to(device)                
            case "vitb16":
                nb_features = 768
                model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
                model.encoder.ln = torch.nn.Sequential(model.encoder.ln, torch.nn.Identity())
                bottleneck = Hook(model.encoder.ln, backward=False)
                model = model.to(device)
            case "swinv2b":
                nb_features = 1024
                model = swin_v2_b(weights=Swin_V2_B_Weights.DEFAULT)
                model.avgpool = torch.nn.Sequential(model.avgpool, torch.nn.Identity())
                bottleneck = Hook(model.avgpool, backward=False)
                model = model.to(device)


        criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=model_args["lr"], weight_decay=model_args["weight_decay"])
        train_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

        build_dict = {
            "model"         : model, 
            "bottleneck"    : bottleneck,
            "nb_features"   : nb_features,
            "criterion"     : criterion,
            "optimizer"     : optimizer,
            "dataset"       : dataset,
            "test_set"      : dataset if setup_test else "none",
            "train_loader"  : train_loader,
            "test_loader"   : train_loader if setup_test else "none",
            "device"        : device,
            "warmup"        : 0,
            "check_dist"    : 25,
            "train_epochs"  : 0
        }
        
        config = {
            "batch_size"            : batch_size,
            "num_classes"           : num_classes,
            "num_bias_attributes"   : num_bias_attributes,
            "num_biases"            : num_biases,
            "env"                   : "none",
            "rho"                   : "none",
            "seed"                  : seed,
            "dataset_args"          : dataset_args,
            "model_args"            : model_args,
            "model_name"            : model_name,
            "dataset_constructor"   : str(type(dataset)),
            "model_name"            : model_name,
            "build_dict"            : build_dict
        }

        return (
            build_dict,   
            config
        )
    