import os
from time import perf_counter
from typing import Any, Dict, List
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
import pandas as pd
import copy
import torchvision
from tqdm import tqdm
from dataset.FairFaceDataset import FairFaceDataset
from models.base import ContinualLearning
from models.utils_model import CNNAttention, ContrastiveCNN, backbone
from dataset.CelebADataset import CelebADataset
from dataset.MTFLDataset import MTFLDataset
from dataset.PhysiQDataset import (
    PhysiQTorchDatasetBuilder,
    PhysiQDataset,
)
from sklearn.manifold import TSNE
from utilities.Logger import Logger
import matplotlib.pyplot as plt
from utilities.stats import track_system_stats
from utilities.utils import fix_seed, split_list

# from utilities.utils_transform import resnet_transform as default_transform
# from utilities.utils_transform import custom_transform
from utilities.utils_transform import get_resnet_transform, get_custom_transform
from argparser import get_args, get_models, get_models_MTL
from torchmetrics import CalibrationError


def load_model(args, device):
    """Build the encoder network for the dataset named in ``args["dataset"]``.

    For every dataset except ``"physiq"`` a torchvision ResNet-18 is created
    (ImageNet weights when ``args["pretrain"]`` is truthy), its first
    convolution is optionally replaced to suit ``args["input_size"]``, its last
    child module is dropped, and a ``Flatten`` + ``Linear(512, z_dim)`` head is
    appended. For ``"physiq"`` a ``ContrastiveCNN`` is built instead.

    Args:
        args: Run configuration; reads ``dataset``, ``pretrain``,
            ``input_size`` and ``z_dim``.
        device: Device the returned module is moved to.

    Returns:
        The encoder module, already on ``device``.

    Raises:
        NotImplementedError: If ``pretrain`` is requested for ``"physiq"``.
    """

    def modify_resnet18_input(model, input_size):
        """Replace ``model.conv1`` with a stem chosen by ``input_size``."""
        if input_size == 224:
            # Stock ResNet-18 resolution: keep the original stem.
            return

        if input_size >= 64:
            conv_kwargs = dict(kernel_size=7, stride=2, padding=3)
        elif input_size >= 32:
            conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        else:
            conv_kwargs = dict(kernel_size=1, stride=1, padding=0)

        # NOTE(review): the replacement layer is freshly initialised, so any
        # pretrained conv1 weights are discarded for input_size != 224.
        model.conv1 = torch.nn.Conv2d(3, 64, **conv_kwargs)

    if args["dataset"] == "physiq":
        # Custom CNN for PhysiQ (a CNNAttention encoder was a previously
        # considered alternative here).
        ret_model = ContrastiveCNN(
            num_timestep=args["input_size"],
            input_channels=6,
            hidden_dims=[16, 32, 64],
            latent_dim=args["z_dim"],
        ).to(device)
        if args["pretrain"]:
            raise NotImplementedError(
                "Pretraining for PhysiQ dataset is not implemented yet due to latent dim conflict (models are hard coded to assume 512 dim)"
            )
        return ret_model

    # Image datasets: ResNet-18 backbone, optionally ImageNet-initialised.
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if args["pretrain"] else None
    resnet = models.resnet18(weights=weights).to(device)
    modify_resnet18_input(resnet, args["input_size"])

    # Drop the last child module and project the 512-d features to z_dim.
    layers = list(resnet.children())[:-1]
    layers.append(torch.nn.Flatten())
    layers.append(torch.nn.Linear(512, args["z_dim"]))
    return torch.nn.Sequential(*layers).to(device)


def get_transform(augment, input_size, model_name):
    """Return the input transform selected by ``augment``.

    Args:
        augment: ``"default"`` for the transform built by
            ``get_resnet_transform``; ``"none"`` for an empty (identity)
            composition.
        input_size: Forwarded to ``get_resnet_transform``.
        model_name: Forwarded to ``get_resnet_transform``.

    Returns:
        A transform object.

    Raises:
        ValueError: If ``augment`` is neither ``"default"`` nor ``"none"``.
    """
    if augment == "default":
        return get_resnet_transform(input_size, model_name)
    if augment == "none":
        # ``Compose`` lives in ``torchvision.transforms``; the top-level
        # ``torchvision`` package has no ``Compose`` attribute.
        return torchvision.transforms.Compose([])
    raise ValueError("Augment not found")


def prepare_data(args, seed, resnet_transform, return_datasets=False):
    """Build the train/valid/test datasets (or dataloaders) for a run.

    Args:
        args: Run configuration; reads ``dataset``, ``train_subsample_ratio``,
            ``job``, ``batch_size`` and, for ``"physiq"``, ``all_binary``.
        seed: Seed forwarded to the dataset constructors.
        resnet_transform: Transform forwarded to the image datasets
            (not used for ``"physiq"``).
        return_datasets: When true, return the dataset objects instead of
            wrapping them in ``DataLoader`` instances.

    Returns:
        ``(train, valid, test, prediction_targets)`` where the first three are
        datasets if ``return_datasets`` is true and dataloaders otherwise.
        ``prediction_targets`` is read from the training dataset before any
        subsampling.

    Raises:
        NotImplementedError: For ``"mtfl"``, which is deliberately disabled.
        ValueError: For an unknown dataset name.
    """
    import warnings

    dataset_name = args["dataset"]
    splits = ("train", "valid", "test")

    # An if/elif chain (rather than a lookup table) so only the selected
    # dataset class is ever touched. Splits are built in train/valid/test order.
    if dataset_name == "celeba":
        train_dataset, valid_dataset, test_dataset = (
            CelebADataset(split=split, transform=resnet_transform, seed=seed)
            for split in splits
        )
    elif dataset_name == "mtfl":
        # MTFL support is intentionally switched off.
        raise NotImplementedError("MTFL not used")
    elif dataset_name == "fairface":
        train_dataset, valid_dataset, test_dataset = (
            FairFaceDataset(split=split, transform=resnet_transform, seed=seed)
            for split in splits
        )
    elif dataset_name == "physiq":
        train_dataset, valid_dataset, test_dataset = (
            PhysiQDataset(split=split, seed=seed, all_binary=args["all_binary"])
            for split in splits
        )
    else:
        raise ValueError("invalid dataset name")

    prediction_targets = train_dataset.prediction_targets

    ratio = args["train_subsample_ratio"]
    if ratio < 1:
        if args["job"] == "cl":
            # Actually emit the warning; constructing ``Warning(...)`` alone
            # does nothing.
            warnings.warn("Subsampling might not work for CL", stacklevel=2)
            # The CL path uses the datasets' own ``random_split`` method and
            # subsamples all three splits.
            train_dataset = train_dataset.random_split(ratio)
            valid_dataset = valid_dataset.random_split(ratio)
            test_dataset = test_dataset.random_split(ratio)
        else:
            # Other jobs only subsample the training split.
            train_dataset, _ = torch.utils.data.random_split(
                train_dataset,
                [ratio, (1 - ratio)],
            )

    if return_datasets:
        return train_dataset, valid_dataset, test_dataset, prediction_targets

    batch_size = args["batch_size"]
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return (
        train_dataloader,
        valid_dataloader,
        test_dataloader,
        prediction_targets,
    )


def train_one_task_cl(
        args,
        task_name,
        task_index,
        model: ContinualLearning,
        train_dataloader,
        valid_dataloader,
        seed,
        device,
        criterion,
        custom_transform,
        tracker=None,
):
    """Train ``model`` on a single continual-learning task with early stopping.

    Each epoch runs one pass over ``train_dataloader`` via
    ``model.compute_loss`` (presumably this also performs the optimisation
    step — no optimizer is handled here), then evaluates on
    ``valid_dataloader`` via ``model.compute_loss_on_task_id``. The model is
    checkpointed with ``save_model`` whenever the mean validation loss
    improves; training stops after ``args["early_stopping_tolerance"]``
    consecutive epochs without improvement.

    Args:
        args: Run configuration; reads ``epochs`` and
            ``early_stopping_tolerance`` (plus whatever ``save_model`` needs).
        task_name: Key of the current task's labels in each batch dict.
        task_index: Index of the current task.
        model: Continual-learning model being trained.
        train_dataloader: Batches are dicts with an ``"image"`` entry and one
            entry per task name.
        valid_dataloader: Same batch layout as ``train_dataloader``.
        seed: Seed used to name the checkpoint.
        device: Device batches are moved to.
        criterion: Loss function forwarded to the model.
        custom_transform: Transform forwarded to ``model.compute_loss``.
        tracker: Optional zero-argument callable invoked once per epoch after
            training; skipped when ``None``.

    Returns:
        ``(valid_loss, avg_epoch_time)``: the mean validation loss of the last
        epoch that ran (``inf`` if no epoch ran) and the running average of
        per-epoch training time in seconds.
    """
    best_loss = float("inf")
    early_stopping_counter = 0
    avg_epoch_time = 0
    # Defined up front so a zero-epoch run still returns a value.
    valid_loss = float("inf")

    for epoch in range(args["epochs"]):
        epoch_start = perf_counter()
        model.train()
        # NOTE(review): begin_task is invoked at the start of every epoch, not
        # once per task — confirm this is what the model implementations expect.
        model.begin_task(train_dataloader, task_name, task_index, criterion=criterion)
        with tqdm(
                total=len(train_dataloader),
                desc=f"[Task {task_index}][Epoch {epoch + 1}/{args['epochs']}][Train]",
        ) as pbar:
            for sample in train_dataloader:
                image = sample["image"].to(device)
                cur_task_y = sample[task_name].type(torch.LongTensor).to(device)
                tqdm_loss = model.compute_loss(
                    image,
                    cur_task_y,
                    image,
                    criterion,
                    custom_transform,
                    task_index,
                )
                pbar.set_postfix(Loss=tqdm_loss)
                pbar.update(1)
        # Incremental mean of the training-pass duration across epochs.
        avg_epoch_time = (
            avg_epoch_time * epoch + (perf_counter() - epoch_start)
        ) / (epoch + 1)
        if tracker is not None:
            tracker()

        total_loss = 0
        model.eval()
        # NOTE(review): validation is not wrapped in torch.no_grad(); left as is
        # because compute_loss_on_task_id's internals are not visible here.
        with tqdm(
                total=len(valid_dataloader),
                desc=f"[Task {task_index}][Epoch {epoch + 1}/{args['epochs']}][Valid]",
        ) as pbar:
            for sample in valid_dataloader:
                image = sample["image"].to(device)
                cur_task_y = sample[task_name].type(torch.LongTensor).to(device)
                _, _, batch_loss = model.compute_loss_on_task_id(
                    image,
                    cur_task_y,
                    criterion,
                    task_index,
                    task_name=task_name,
                )
                batch_loss_value = batch_loss.item()
                total_loss += batch_loss_value
                pbar.set_postfix(Loss=batch_loss_value)
                pbar.update(1)

        # Mean validation loss for this epoch. The previous code added this
        # mean back onto the sum before averaging again, inflating the result.
        valid_loss = total_loss / len(valid_dataloader)
        if valid_loss < best_loss:
            best_loss = valid_loss
            save_model(model, args, seed, task_name)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
        if early_stopping_counter >= args["early_stopping_tolerance"]:
            break
    model.end_task(train_dataloader, task_name, task_index)

    return valid_loss, avg_epoch_time


def train_cl(
        args: Dict[str, Any],
        device: torch.device,
        seeds: List[int],
        logger: Logger = None,
):
    """Run the continual-learning experiment for every seed and every task.

    Datasets are prepared once (with ``seeds[0]``). For each seed a fresh model
    is built and trained task by task with ``train_one_task_cl``; after each
    task the best checkpoint is reloaded, accuracies on all tasks seen so far
    are computed on the test set, a t-SNE scatter plot is written to
    ``./figures`` and the metrics are appended to ``logger``. Aggregated
    results are finally written with ``save_results`` and ``logger.write``.

    Args:
        args: Run configuration dictionary.
        device: Device used for the model and batches.
        seeds: Seeds to run; must be non-empty (``seeds[0]`` is read).
        logger: Results logger. Although it defaults to ``None``, its
            ``results_append`` and ``write`` methods are called
            unconditionally, so a real logger is required in practice.
    """
    results = []
    init_transform = get_transform(args["augment"], args["input_size"], args["model"])
    encoder_past = load_model(args, device)
    train_dataset_overall, valid_dataset, test_dataset, prediction_targets = (
        prepare_data(args, seeds[0], init_transform, return_datasets=True)
    )
    criterion = torch.nn.CrossEntropyLoss()
    test_dataloader = DataLoader(
        test_dataset, batch_size=args["batch_size"], shuffle=False
    )
    valid_dataloader = DataLoader(
        valid_dataset, batch_size=args["batch_size"], shuffle=False
    )
    custom_transform = get_custom_transform(
        input_size=args["input_size"], dataset_name=args["dataset"]
    )

    tsne_reducer = TSNE(n_components=2, random_state=325235, perplexity=20, max_iter=5000)

    # savefig does not create missing directories.
    os.makedirs("./figures", exist_ok=True)

    with track_system_stats(logger) as system_tracker:
        for seed in seeds:
            torch.cuda.empty_cache()
            fix_seed(seed)
            # Give every seed its own copy of the initial encoder (as train_mtl
            # does) so weights trained under one seed cannot leak into the next.
            cur_encoder = copy.deepcopy(encoder_past).to(device)
            model = get_models(
                args,
                cur_encoder,
                cls_output_dim=args["cls_output_dim"],
                lr=args["lr"],
                input_size=args["input_size"],
                dataset_name=args["dataset"],
                buffer_size=args["buffer_size"],
                num_tasks=len(prediction_targets),
                z_dim=args["z_dim"],
                n_epochs=args["epochs"],
                enable_dynamic=not args["disable_dynamic"],
                dist_method=args["dist_method"],
                device=device,
            ).to(device)
            for task_index, task_name in enumerate(prediction_targets):
                if args['split_task']:
                    train_dataset = train_dataset_overall.split_data_by_task(
                        task_index
                    )
                else:
                    train_dataset = train_dataset_overall
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args["batch_size"], shuffle=True
                )

                total_loss, avg_epoch_time = train_one_task_cl(
                    args,
                    task_name,
                    task_index,
                    model,
                    train_dataloader,
                    valid_dataloader,
                    seed,
                    device,
                    criterion,
                    custom_transform,
                    system_tracker,
                )
                # Evaluate the best checkpoint of this task on every task seen
                # so far.
                load_best_model(model, args, seed, task_name)
                seen_tasks = prediction_targets[: task_index + 1]
                temp_result, zs, ys = model.calculate_accuraciess(
                    test_dataloader,
                    seen_tasks,
                    device,
                )
                temp_result["time"] = avg_epoch_time
                results.append(
                    generate_result_dict(
                        seed,
                        task_name,
                        total_loss,
                        train_dataloader,
                        temp_result,
                        seen_tasks,
                        task_index,
                    )
                )  # NOTE: outdated function, outputted into csv

                # zs/ys are presumably latent embeddings and their labels
                # (returned by calculate_accuraciess) — TODO confirm.
                zs = tsne_reducer.fit_transform(zs)
                # Draw on a dedicated figure and close it afterwards; plotting
                # on the implicit global figure would overlay points from
                # earlier tasks/seeds and keep every figure alive.
                fig, ax = plt.subplots()
                ax.scatter(zs[:, 0], zs[:, 1], c=ys)  # colour by label
                fig.tight_layout()
                fig.savefig(f"./figures/{seed}_tsne_{args['dataset']}_{args['model']}_task_{task_index}.png")
                plt.close(fig)

                system_tracker.print_stats()
                logger.results_append(temp_result)
    save_results(results, args, args["job"].upper(), len(prediction_targets))
    logger.write(len(prediction_targets))

# def train_ind(args: Dict[str, Any], device: torch.device, seeds: List[int], logger: Logger = None):
#     init_transform = get_transform(args["augment"], args["input_size"])
#     encoder_past = load_model(args, device)
#     train_dataset_overall, valid_dataset, test_dataset, prediction_targets = (
#         prepare_data(args, seeds[0], init_transform, return_datasets=True)
#     )
#     criterion = torch.nn.CrossEntropyLoss()
#     test_dataloader = DataLoader(
#         test_dataset, batch_size=args["batch_size"], shuffle=False
#     )
#     valid_dataloader = DataLoader(
#         valid_dataset, batch_size=args["batch_size"], shuffle=False
#     )
#     with track_system_stats(logger) as system_tracker:
#         for task_index, task_name in enumerate(prediction_targets):
#             torch.cuda.empty_cache()
#             for epochs in range(args["epochs"]):
#                 cur_encoder = copy.deepcopy(encoder_past).to(device)
#                 model = backbone(cur_encoder, cls_output_dim=2).to(device)
#                 # model = get_models(
                    
                

def train_mtl(
        args: Dict[str, Any],
        device: torch.device,
        seeds: List[int],
        logger: Logger = None,
):
    """Run the multi-task-learning experiment for every seed.

    For each seed the dataloaders are rebuilt, a model is created on a deep
    copy of the initial encoder, and training runs for up to ``args["epochs"]``
    epochs with early stopping on the validation loss. The best checkpoint is
    then reloaded and evaluated on the test set. Aggregated results are written
    with ``save_results`` and ``logger.write``.

    Args:
        args: Run configuration dictionary.
        device: Device used for the model and batches.
        seeds: Seeds to run.
        logger: Results logger. Although it defaults to ``None``, its
            ``results_append`` and ``write`` methods are called
            unconditionally, so a real logger is required in practice.
    """
    results = []
    init_transform = get_transform(args["augment"], args["input_size"], args["model"])
    encoder_past = load_model(args, device)
    criterion = torch.nn.CrossEntropyLoss()
    with track_system_stats(logger) as system_tracker:
        for seed in seeds:
            fix_seed(seed)
            (
                train_dataloader,
                valid_dataloader,
                test_dataloader,
                prediction_targets,
            ) = prepare_data(args, seed, init_transform)
            # Each seed starts from the same untouched initial encoder.
            cur_encoder = copy.deepcopy(encoder_past).to(device)
            model = get_models_MTL(
                args,
                cur_encoder,
                # Every task is treated as binary classification.
                tasks_name_to_cls_num={
                    task_name: 2 for task_name in prediction_targets
                },
                z_dim=args['z_dim'],
                cls_output_dim=args["cls_output_dim"],
            ).to(device)

            best_loss = float("inf")
            early_stopping_counter = 0
            avg_epoch_time = 0
            for epoch in range(args["epochs"]):
                train_total_loss, e_time = train_epoch_mtl(
                    model,
                    train_dataloader,
                    criterion,
                    device,
                    args,
                    epoch,
                    prediction_targets,
                    system_tracker,
                )
                valid_total_loss = validate_epoch_mtl(
                    model,
                    valid_dataloader,
                    criterion,
                    device,
                    args,
                    epoch,
                    prediction_targets,
                )
                # Incremental mean of the per-epoch training time.
                avg_epoch_time = (avg_epoch_time * epoch + e_time) / (epoch + 1)
                if valid_total_loss < best_loss:
                    best_loss = valid_total_loss
                    save_model(model, args, seed)
                    # Patience counts consecutive non-improving epochs, so an
                    # improvement resets it (matches train_one_task_cl).
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1
                if early_stopping_counter >= args["early_stopping_tolerance"]:
                    break

            # Inference on the test set with the best checkpoint.
            load_best_model(model, args, seed)
            temp_result = model.calculate_accuraciess(
                test_dataloader, prediction_targets, device
            )
            temp_result["time"] = avg_epoch_time
            logger.results_append(temp_result)
            results.append(
                generate_result_dict_mtl(
                    seed,
                    epoch,
                    train_total_loss,
                    valid_total_loss,
                    temp_result,
                    prediction_targets,
                )
            )

    save_results(results, args, args["job"].upper(), len(prediction_targets))
    logger.write(len(prediction_targets))


def train_epoch_mtl(
        model,
        train_dataloader,
        criterion,
        device,
        args,
        epoch,
        prediction_targets,
        tracker=None,
):
    """Run one multi-task training epoch.

    Args:
        model: Model exposing ``compute_loss(image, tasks_y, criterion)``,
            whose result supports ``.item()``.
        train_dataloader: Batches are dicts with an ``"image"`` entry and one
            entry per name in ``prediction_targets``.
        criterion: Loss function forwarded to the model.
        device: Device batches are moved to.
        args: Run configuration; ``epochs`` is read for the progress label.
        epoch: Zero-based epoch index (progress label only).
        prediction_targets: Task names to pull labels for.
        tracker: Optional zero-argument callable invoked once after the
            epoch; skipped when ``None``.

    Returns:
        ``(mean_train_loss, elapsed_seconds)`` for this epoch.
    """
    model.train()
    start_time = perf_counter()
    train_total_loss = 0
    with tqdm(
            total=len(train_dataloader),
            desc=f"[Epoch {epoch + 1}/{args['epochs']}][Train]",
    ) as pbar:
        for sample in train_dataloader:
            image = sample["image"].to(device)
            tasks_y = {
                task_name: sample[task_name].type(torch.long).to(device)
                for task_name in prediction_targets
            }
            train_loss = model.compute_loss(image, tasks_y, criterion)
            loss_value = train_loss.item()
            train_total_loss += loss_value
            pbar.set_postfix(train_loss=loss_value)
            pbar.update(1)
    elapsed = perf_counter() - start_time
    # The tracker is optional (defaults to None), so only call it when given.
    if tracker is not None:
        tracker()
    return train_total_loss / len(train_dataloader), elapsed


def validate_epoch_mtl(
        model, valid_dataloader, criterion, device, args, epoch, prediction_targets
):
    """Evaluate the multi-task model on the validation set for one epoch.

    Puts ``model`` in eval mode and, with gradients disabled, sums the value of
    ``model.compute_loss_nograd`` over every validation batch.

    Returns:
        The mean validation loss per batch.
    """
    model.eval()
    num_batches = len(valid_dataloader)
    progress_label = f"[Epoch {epoch + 1}/{args['epochs']}][Valid]"
    running_loss = 0

    with torch.no_grad(), tqdm(total=num_batches, desc=progress_label) as pbar:
        for batch in valid_dataloader:
            inputs = batch["image"].to(device)
            # One long-typed label tensor per task, keyed by task name.
            labels = {}
            for target in prediction_targets:
                labels[target] = batch[target].type(torch.long).to(device)
            batch_loss = model.compute_loss_nograd(inputs, labels, criterion).item()
            running_loss += batch_loss
            pbar.set_postfix(valid_loss=batch_loss)
            pbar.update(1)

    return running_loss / len(valid_dataloader)


def save_model(model: torch.nn.Module, args: dict, seed: int, task_name=None):
    """Save ``model``'s state dict as the current best checkpoint.

    The file is written to
    ``./saved_models/<JOB>/<dataset>/<model>/`` and is named
    ``<seed>_<task_name>_best_enc.pt`` for the ``"cl"`` job or
    ``<seed>_best_enc.pt`` otherwise.

    Saving is best-effort: failures are reported on stdout and do not raise.

    Args:
        model: Module whose ``state_dict()`` is saved.
        args: Run configuration; reads ``job``, ``dataset`` and ``model``.
        seed: Seed used in the file name.
        task_name: Task name used in the file name for the ``"cl"`` job.
    """
    directory = os.path.join(
        "./saved_models", args["job"].upper(), args["dataset"], args["model"]
    )

    if args["job"] == "cl":
        filename = f"{seed}_{task_name}_best_enc.pt"
    else:
        filename = f"{seed}_best_enc.pt"

    path = os.path.join(directory, filename)

    # exist_ok avoids the check-then-create race; a failure is reported rather
    # than silently skipping the save.
    try:
        os.makedirs(directory, exist_ok=True)
    except OSError as e:
        print(f"Failed to create model directory {directory}: {e}")
        return

    try:
        torch.save(model.state_dict(), path)
    except Exception as e:
        print(f"Failed to save model: {e}")


def load_best_model(model, args, seed, task_name=None):
    """Load the best checkpoint for this run into ``model`` in place.

    The checkpoint is read from
    ``./saved_models/<JOB>/<dataset>/<model>/``; the file name includes
    ``task_name`` only for the ``"cl"`` job.
    """
    checkpoint_dir = (
        f"./saved_models/{args['job'].upper()}/{args['dataset']}/{args['model']}/"
    )
    if args["job"] == "cl":
        checkpoint_name = f"{seed}_{task_name}_best_enc.pt"
    else:
        checkpoint_name = f"{seed}_best_enc.pt"

    # NOTE(review): weights_only=False permits arbitrary pickled objects; only
    # load checkpoints produced by this project.
    state_dict = torch.load(checkpoint_dir + checkpoint_name, weights_only=False)
    model.load_state_dict(state_dict)


def generate_result_dict(
        seed,
        task_name,
        total_loss,
        train_dataloader,
        temp_result: dict,
        prediction_targets,
        task_index,
):
    """Build (and print) the per-task result row for a continual-learning run.

    Args:
        seed: Seed of the run.
        task_name: Name of the task that was just trained.
        total_loss: Validation loss for the task. The caller in this file
            passes a value that is already averaged over validation batches,
            so it is recorded as is.
        train_dataloader: Unused; kept so existing call sites keep working.
        temp_result: Metrics dict holding ``<task>`` (accuracy) and
            ``<task>_ece`` for every task in ``prediction_targets``.
        prediction_targets: Task names to report.
        task_index: Index of the task that was just trained (printed only).

    Returns:
        Dict with ``seed``, ``task``, ``loss`` and, per reported task,
        ``<task>_accuracy`` and ``<task>_ece``.

    Raises:
        KeyError: If ``temp_result`` lacks an accuracy or ECE entry.
    """
    # The loss arrives already normalised; dividing it again by the number of
    # training batches would shrink it by an unrelated factor.
    loss = total_loss
    cur_result = {
        "seed": seed,
        "task": task_name,
        "loss": loss,
    }
    print(
        f"[Seed {seed}][Task {task_index}][{task_name}][Loss: {loss: .2f}][",
        end=" ",
    )
    for each_task in prediction_targets:
        ece_key = f"{each_task}_ece"
        accuracy = temp_result[each_task]
        ece = temp_result[ece_key]
        print(f"{each_task}: {accuracy:.2f}, {ece:.5f}", end=" ")
        cur_result[f"{each_task}_accuracy"] = accuracy
        cur_result[ece_key] = ece
    print("]")
    return cur_result


def generate_result_dict_mtl(
        seed,
        epoch,
        train_total_loss,
        valid_total_loss,
        temp_result,
        prediction_targets,
):
    """Build (and print) the result row for one multi-task-learning seed.

    Returns:
        Dict with ``seed``, ``epoch``, ``train loss``, ``valid loss`` and one
        ``<task>_accuracy`` entry per task in ``prediction_targets``.

    Raises:
        KeyError: If ``temp_result`` has no entry for a requested task.
    """
    summary = dict(seed=seed, epoch=epoch)
    summary["train loss"] = train_total_loss
    summary["valid loss"] = valid_total_loss

    print(f"[Seed {seed}][", end="")
    for task in prediction_targets:
        accuracy = temp_result[task]
        summary[f"{task}_accuracy"] = accuracy
        print(f"{task}: {accuracy:.2f}", end=" ")
    print("]")
    return summary


def save_results(results, args, job_type, num_tasks):
    """Write the collected result rows to a CSV file.

    The file is
    ``./experiment_results/<JOB_TYPE>/<dataset>/<model>/<num_tasks>tasks.csv``;
    missing directories are created.

    Args:
        results: Row data accepted by ``pandas.DataFrame`` (here a list of
            dicts).
        args: Run configuration; reads ``dataset`` and ``model``.
        job_type: Job name; upper-cased for the directory name.
        num_tasks: Number of tasks, used in the file name.
    """
    df = pd.DataFrame(results)
    out_dir = os.path.join(
        "./experiment_results", job_type.upper(), args["dataset"], args["model"]
    )
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"{num_tasks}tasks.csv")
    df.to_csv(csv_path, index=False)
