from os.path import join, exists, dirname, abspath
from os import makedirs
import sys
import torch
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

# Add parent directory to path to find teacher_student package
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from teacher_student.models import get_model, initialize_weights
from teacher_student.utils import ModelClassifier, mk_fname
from teacher_student.dataloader import TaskDataModule, get_dataset
from teacher_student.training_routines import train_dual_heads
from teacher_student.task_utils import update_tasks_mnist




def _freeze_parameters(module):
    """Disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = False


def _build_trainer(loggers, max_epochs, cluster):
    """Create a GPU ``pl.Trainer`` with the settings shared by all pre-training runs."""
    return pl.Trainer(
        logger=list(loggers),  # fresh list per trainer, as in the original code
        accelerator="gpu",
        max_epochs=max_epochs,
        # NOTE(review): the original author flagged deterministic mode as a
        # "problem for CNN adaptive pooling" -- confirm before changing models.
        deterministic=True,
        enable_progress_bar=not cluster,
    )


def train_teacher_student(config: dict, save_dir: str, data_dir: str, cluster: bool):
    """Train two teachers (and optionally a full model), then a dual-head student.

    The routine seeds everything, builds the models and data modules, copies
    the teachers' output heads into the student's two heads, freezes all
    heads, trains the teachers, saves their weights, and finally hands the
    student over to ``train_dual_heads``.

    Args:
        config: Experiment configuration. Keys read here:
            ``seed``, ``full_model``/``full_kwargs``,
            ``teacher0_model``/``teacher0_kwargs``,
            ``teacher1_model``/``teacher1_kwargs``,
            ``student_model``/``student_kwargs``, ``control`` (a dict;
            optional entries ``tasks_key``/``tasks_keys`` and ``alpha``),
            ``task0_labels``/``task1_labels`` (only when no tasks key is
            given), ``dataset``, ``BATCH_SIZE``, ``student_equal_heads``,
            ``learning_rate``, ``momentum``, ``epochs_teacher``,
            ``epochs_t0``, ``epochs_t1``, plus the optional ``full_ds``,
            ``init_method`` (default ``"kaiming"``), ``a`` (default ``1.0``),
            ``verify_head_assignment`` and ``min_param_value``.
        save_dir: Output root. Logs go to ``<save_dir>/logs``, weights to
            ``<save_dir>/parameters`` and the results CSV directly into
            ``save_dir``.
        data_dir: Directory forwarded to every ``TaskDataModule``.
        cluster: When true, data loaders use 1 worker (instead of 20) and
            the trainers' progress bars are disabled.

    Returns:
        None. Results are written to disk.
    """
    # set seed
    seed = config["seed"]
    seed_everything(seed)
    torch.set_float32_matmul_precision("high")

    use_full = config.get("full_ds", False)
    equal_heads = config["student_equal_heads"]
    control = config["control"]

    # configure models
    full_model = get_model(config["full_model"], config["full_kwargs"])
    teacher0_model = get_model(config["teacher0_model"], config["teacher0_kwargs"])
    teacher1_model = get_model(config["teacher1_model"], config["teacher1_kwargs"])
    student_model = get_model(config["student_model"], config["student_kwargs"])

    # Resolve the label sets of the two tasks. The original code tested for
    # "tasks_keys" but then read "tasks_key", so it either raised KeyError or
    # silently ignored the setting; accept either spelling and read the one
    # that is actually present.
    tasks_key_name = next(
        (name for name in ("tasks_key", "tasks_keys") if name in control), None
    )
    if tasks_key_name is not None:
        task1_labels, task2_labels = update_tasks_mnist(control[tasks_key_name])
    else:
        task1_labels = config["task0_labels"]
        task2_labels = config["task1_labels"]

    # configure dataloaders
    dm_common = dict(
        dataset_class=get_dataset(config["dataset"]),
        batch_size=config["BATCH_SIZE"],
        data_dir=data_dir,
        seed=seed,
        num_workers=1 if cluster else 20,
    )
    dm_full = None
    if use_full:
        dm_full = TaskDataModule(selected_labels=None, **dm_common)
    dm_task1 = TaskDataModule(selected_labels=task1_labels, **dm_common)
    # Only the second task's data module receives the optional ``alpha`` control.
    dm_task2 = TaskDataModule(
        selected_labels=task2_labels,
        alpha=control.get("alpha", None),
        **dm_common,
    )

    if use_full:
        dm_full.setup()
    dm_task1.setup()
    dm_task2.setup()

    # initialize parameters
    init_method = config.get("init_method", "kaiming")
    a = config.get("a", 1.0)

    if use_full:
        initialize_weights(full_model, seed, init_method, a)

    initialize_weights(teacher0_model, seed, init_method, a)

    if equal_heads:
        # Equal student heads: both teachers share the seed, student differs.
        initialize_weights(teacher1_model, seed, init_method, a)
        initialize_weights(student_model, seed + 1, init_method, a)
        print("Teachers initialized with same seed (equal heads)")
    else:
        # Independent student heads: every model gets its own seed.
        initialize_weights(teacher1_model, seed + 1, init_method, a)
        initialize_weights(student_model, seed + 2, init_method, a)
        print("Teachers initialized with different seeds (different heads)")

    # Copy teacher heads into the student: h0 always mirrors teacher 0,
    # h1 mirrors teacher 0 (equal heads) or teacher 1 (independent heads).
    student_model.h0.load_state_dict(teacher0_model.h.state_dict())

    if equal_heads:
        student_model.h1.load_state_dict(teacher0_model.h.state_dict())
        print("Student heads set to be equal (both matching teacher0)")
    else:
        student_model.h1.load_state_dict(teacher1_model.h.state_dict())
        print("Student heads set independently (h0=teacher0, h1=teacher1)")

    # Optional: report whether the head assignment worked (prints only).
    if config.get("verify_head_assignment", False):
        t0_weight = teacher0_model.h.weight
        t1_weight = teacher1_model.h.weight
        t0_equal_t1 = torch.allclose(t0_weight, t1_weight)
        print(f"Teacher heads verification - t0==t1: {t0_equal_t1}")

        h0_equal_t0 = torch.allclose(student_model.h0.weight, t0_weight)
        if equal_heads:
            h1_equal_t0 = torch.allclose(student_model.h1.weight, t0_weight)
            h0_equal_h1 = torch.allclose(student_model.h0.weight, student_model.h1.weight)
            print(f"Student heads verification - h0==t0: {h0_equal_t0}, h1==t0: {h1_equal_t0}, h0==h1: {h0_equal_h1}")
        else:
            h1_equal_t1 = torch.allclose(student_model.h1.weight, t1_weight)
            print(f"Student heads verification - h0==t0: {h0_equal_t0}, h1==t1: {h1_equal_t1}")

    # freeze heads (output layers)
    if use_full:
        _freeze_parameters(full_model.h)
    _freeze_parameters(teacher0_model.h)
    _freeze_parameters(teacher1_model.h)
    _freeze_parameters(student_model.h0)
    _freeze_parameters(student_model.h1)

    # assign pytorch_lightning models
    lr = config["learning_rate"]
    momentum = config["momentum"]
    full_classifier = None
    if use_full:
        full_classifier = ModelClassifier(model=full_model, lr=lr, momentum=momentum)
    teacher0 = ModelClassifier(model=teacher0_model, lr=lr, momentum=momentum)
    teacher1 = ModelClassifier(model=teacher1_model, lr=lr, momentum=momentum)

    # define loggers; exist_ok avoids the check-then-create race
    logs_dir = join(save_dir, "logs")
    makedirs(logs_dir, exist_ok=True)

    loggers = (
        TensorBoardLogger(save_dir=logs_dir, name="tensorboard"),
        CSVLogger(save_dir=logs_dir, name="csv"),
    )

    # define trainers (all are created before any training starts)
    full_trainer = None
    if use_full:
        full_trainer = _build_trainer(loggers, config["epochs_teacher"], cluster)
    teacher0_trainer = _build_trainer(loggers, config["epochs_teacher"], cluster)
    teacher1_trainer = _build_trainer(loggers, config["epochs_teacher"], cluster)

    # train full model and teachers
    if use_full:
        full_trainer.fit(full_classifier, dm_full)
    teacher0_trainer.fit(teacher0, dm_task1)
    teacher1_trainer.fit(teacher1, dm_task2)

    # folder to store parameters
    param_dir = join(save_dir, "parameters")
    makedirs(param_dir, exist_ok=True)

    # save final solutions of full model and teachers
    if use_full:
        torch.save(full_classifier.model.state_dict(), join(param_dir, "full.pth"))
    torch.save(teacher0.model.state_dict(), join(param_dir, "teacher0.pth"))
    torch.save(teacher1.model.state_dict(), join(param_dir, "teacher1.pth"))

    # train student
    result_df = train_dual_heads(
        student=student_model,
        teachers=[teacher0, teacher1],
        datamodules=[dm_task1, dm_task2],
        epochs_list=[config["epochs_t0"], config["epochs_t1"]],
        save_dir=join(param_dir, "student"),
        lr=lr,
        momentum=momentum,
        min_param_value=config.get("min_param_value", None),
    )

    result_df.to_csv(join(save_dir, mk_fname(filename="results", label=None, suffix="csv")))
