import time
import random
import math
import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
from decorator import decorator
from pathlib import Path
import yaml
import os

from data.custom_dataset import *

def get_ours_loaders(ds_la, ds_lu, ds_ua, ds_uu, batch_size, args):
    """Build the two shuffled training loaders used by the "ours" method.

    Args:
        ds_la, ds_lu, ds_ua, ds_uu: the four training subsets produced by
            ``_create_train_in_memory`` (labeled-aligned, labeled-unaligned,
            unlabeled-aligned, unlabeled-unaligned).
        batch_size: mini-batch size shared by both loaders.
        args: unused; kept so existing call sites keep working.

    Returns:
        ``(unlabeled_loader, labeled_loader)``. The first iterates over all
        four subsets; the second only over ``ds_la`` and ``ds_lu``.
    """
    def _shuffled(ds):
        # Both loaders share identical batching / collation settings.
        return DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    every_split = ConcatDataset([ds_la, ds_lu, ds_ua, ds_uu])
    labeled_only = ConcatDataset([ds_la, ds_lu])
    return _shuffled(every_split), _shuffled(labeled_only)

def get_ours_test_loaders(dict_tla, dict_tlu, batch_size, args):
    """Build one unshuffled evaluation loader per test configuration.

    For every configuration index 0..7, the fully observed subset
    (``dict_tla[i]``) and the partially observed subset (``dict_tlu[i]``) are
    concatenated back into a single dataset.

    Args:
        dict_tla, dict_tlu: mappings from configuration index to dataset.
        batch_size: evaluation batch size.
        args: unused; kept so existing call sites keep working.

    Returns:
        A list of 8 ``DataLoader`` objects ordered by configuration index.
    """
    return [
        DataLoader(
            ConcatDataset([dict_tla[cfg], dict_tlu[cfg]]),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn,
        )
        for cfg in range(8)
    ]

    
def log_pattern_counts_to_wandb(dataset, name):
    """Log a bar chart of how many mask entries are True per sample.

    Every item of ``dataset`` is read once and unpacked as ``(x, y, mask)``.
    The mask is summed, and samples are tallied by that number of True
    entries. The tally is sent to wandb as a bar chart under the key
    ``"{name}_pattern_distribution_by_num_true"``.

    Args:
        dataset: indexable dataset whose items are 3-tuples ending in a mask.
        name: prefix used in the chart title and in the wandb log key.
    """
    tally = {}  # number of True mask entries -> number of samples with it
    for idx in range(len(dataset)):
        _, _, mask = dataset[idx]
        n_true = torch.sum(mask).item()
        tally[n_true] = tally.get(n_true, 0) + 1

    rows = [[n_true, how_many] for n_true, how_many in tally.items()]
    table = wandb.Table(data=rows, columns=["num_true", "count"])
    chart = wandb.plot.bar(
        table,
        "num_true",
        "count",
        title=f"{name} Mask Pattern Distribution by True Count",
    )
    wandb.log({f"{name}_pattern_distribution_by_num_true": chart})


def build_cache_file_name(args, prefix="train"):
    """Return the cache file name (no directory) for a dataset split.

    For ``prefix == "train"`` the name encodes the following, which together
    determine the generated split:

    * the task and number of clients;
    * the missing-data mechanism and its parameter;
    * when present on ``args``, the labeled-aligned / labeled-unaligned counts.

    Any other prefix yields ``"{prefix}_{task}_nb{num_clients}.pt"``.

    Probabilities are turned into integer percentages with ``round`` rather
    than plain ``int`` truncation. Because of binary floating point,
    ``0.29 * 100`` is ``28.999999999999996`` and ``int`` gives 28, so 0.28 and
    0.29 would share, and silently reuse, one cache file.

    Args:
        args: namespace with ``task_name``, ``num_clients`` and
            ``missing_type``. ``p_mcar`` (mcar), ``option`` (mar) or
            ``p_miss`` (mnar) is read for the matching ``missing_type``.
        prefix: ``"train"`` or another split prefix such as ``"test"``.

    Returns:
        The cache file name, ending in ``.pt``.
    """
    if prefix != "train":
        return f"{prefix}_{args.task_name}_nb{args.num_clients}.pt"

    parts = [f"{prefix}_{args.task_name}_nb{args.num_clients}_{args.missing_type}"]
    if args.missing_type == "mcar":
        parts.append(f"_p{int(round(args.p_mcar * 100))}")
    elif args.missing_type == "mar":
        # No separator here, matching the names of already-existing caches.
        parts.append(f"{args.option}")
    elif args.missing_type == "mnar":
        parts.append(f"{int(round(args.p_miss * 100))}")

    if hasattr(args, "labeled_aligned_num"):
        parts.append(f"_la{args.labeled_aligned_num}")
    if hasattr(args, "labeled_unaligned_num"):
        parts.append(f"_lu{args.labeled_unaligned_num}")
    parts.append(".pt")
    return "".join(parts)

    
def _create_train_in_memory(args):
    """Build the four training subsets for ``args.task_name`` from scratch.

    Steps:

    1. Load the base training set. For modelnet10 / hapt / isolet the
       normalisation statistics are also saved next to the dataset folder
       (``*_stats.pt``); ``create_test_config`` loads those files later.
    2. Shuffle all sample indices with the global ``random`` module. The first
       ``args.labeled_aligned_num`` of them keep an all-True mask.
    3. Apply ``args.missing_type`` (mcar / mar / mnar) to the remaining
       ("leftover") samples via ``MissingAppliedDataset_sample`` and copy the
       resulting masks into ``mask_array``.
    4. Treat the first ``args.labeled_unaligned_num`` leftover indices as
       labeled candidates and the rest as unlabeled candidates. Route each
       index to one of the four sets according to its mask (see the comments
       on the two loops below).

    Args:
        args: namespace. Reads ``task_name``, ``num_clients``,
            ``missing_type``, ``labeled_aligned_num``,
            ``labeled_unaligned_num``, and, depending on ``missing_type``,
            ``p_mcar`` / ``option`` / ``p_miss``.

    Returns:
        ``(ds_la, ds_lu, ds_ua, ds_uu)``: ``Subset`` views with sorted indices
        over ``FinalMissingDataset(base_ds, mask_array)``.

    Raises:
        ValueError: unknown ``args.missing_type``.
        KeyError: a task name that is not one of the explicit branches and
            not in ``datasets_dict`` (which only has ``'fashionmnist'``).

    Side effects:
        * writes a stats file (except for fashionmnist);
        * logs a wandb chart through ``log_pattern_counts_to_wandb``;
        * consumes the global ``random`` state (``random.shuffle``), so
          statement order matters for reproducibility.
    """
    dataset = args.task_name
    if dataset == 'modelnet10':
        root_dir = Path(__file__).absolute().parent / 'dataset' / 'modelnet10_aligned'
        mean, std = compute_mean_std_for_modelnet10(root_dir, split='train', num_views=12, max_samples=3000)
        print(f"[modelnet10] computed mean={mean}, std={std}")

        # Persist the train statistics so the test set is normalised identically.
        torch.save({'mean':mean, 'std':std}, f'{root_dir.parent}/modelnet10_stats.pt')
    
        transform_final = transforms.Compose([
            transforms.Resize((32,32)),
            transforms.ToTensor(),
            transforms.Normalize(mean.tolist(), std.tolist())
        ])
    
        base_ds = ModelNet10MultiViewDataset(
            root_dir=root_dir,
            split='train',
            transform=transform_final,
            num_repeat=6
        )
    elif dataset == 'hapt':
        root_dir = Path(__file__).absolute().parent / 'dataset' / 'hapt' / 'Train'

        base_ds = HAPTDataset(
            X_file=f'{root_dir}/X_train.txt',
            y_file=f'{root_dir}/y_train.txt',
            num_clients=args.num_clients
        )

        # Persist the statistics computed by HAPTDataset for reuse on the test split.
        torch.save({'mean':base_ds.mean, 'std':base_ds.std}, f'{root_dir.parent.parent}/hapt_stats.pt')

    elif dataset == 'isolet':
        root_dir = Path(__file__).absolute().parent / 'dataset' / 'isolet'

        base_ds = IsoletDataset(
            data_file=f'{root_dir}/isolet1+2+3+4.data',
            num_clients=args.num_clients
        )

        # Persist the statistics computed by IsoletDataset for reuse on the test split.
        torch.save({'mean':base_ds.mean, 'std':base_ds.std}, f'{root_dir.parent}/isolet_stats.pt')
        
    else:    
        data_dir = Path(__file__).absolute().parent / 'dataset' / dataset

        # NOTE(review): `transform` is only bound for fashionmnist; any other name
        # fails at the datasets_dict lookup below (KeyError) before it is used.
        if dataset == 'fashionmnist':
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))])
        datasets_dict = {'fashionmnist': datasets.FashionMNIST}
        DatasetClass = datasets_dict[dataset]
        base_ds=DatasetClass(root=data_dir,train=True,download=True,transform=transform)
    
    n=len(base_ds)
    indices = list(range(n))
    # Global RNG: the split depends on the seed set via set_seed().
    random.shuffle(indices)

    # One mask row per sample, one column per client. Rows start all True; only
    # leftover rows are overwritten below, so the labeled-aligned samples stay
    # fully observed.
    mask_array = torch.ones((n, args.num_clients), dtype=torch.bool)

    # NOTE(review): dead assignment; leftover_idx is reassigned a few lines below.
    leftover_idx = []

    la_num = args.labeled_aligned_num
        
    pre_la_idx = indices[:la_num]
    leftover_idx = indices[la_num:]

    leftover_sub = Subset(base_ds, leftover_idx)
    
    # Apply the requested missing-data mechanism to the leftover samples only.
    if args.missing_type == 'mcar':
        leftover_ds = MissingAppliedDataset_sample(leftover_sub, args.task_name, args.num_clients, 'mcar', p_mcar=args.p_mcar)
    elif args.missing_type == 'mar':
        leftover_ds = MissingAppliedDataset_sample(leftover_sub, args.task_name, args.num_clients, 'mar', p_mcar=None, p_miss=None, option=args.option, init_cut_var=0.1 * (args.num_clients-1))
    elif args.missing_type == 'mnar':
        leftover_ds = MissingAppliedDataset_sample(leftover_sub, args.task_name, args.num_clients, 'mnar', p_mcar=None, p_miss=args.p_miss)
    else:
        raise ValueError(f"Unknown missing_type: {args.missing_type}")
        

    # NOTE(review): leftover_ds is indexed here and again in the loop below. If
    # MissingAppliedDataset_sample draws masks on access rather than once at
    # construction, the logged counts and the stored masks could differ — confirm.
    log_pattern_counts_to_wandb(leftover_ds, 'Train')
    for i in range(len(leftover_ds)):
        # Position i in leftover_ds corresponds to base_ds index leftover_idx[i].
        real_idx = leftover_idx[i]  
        _,_,msk = leftover_ds[i]
        mask_array[real_idx] = msk

    # leftover_count
    leftover_count = len(leftover_idx)
    l_unaligned_num  = args.labeled_unaligned_num
    # NOTE(review): unlabeled_num is not used below.
    unlabeled_num  = leftover_count - l_unaligned_num

    # slice leftover
    leftover_lu_idx = leftover_idx[:l_unaligned_num]
    leftover_u_idx = leftover_idx[l_unaligned_num : ]


    # final sets
    la_indices = set(pre_la_idx)  # already full True
    lu_indices = set()
    ua_indices = set()
    uu_indices = set()
        
    # Labeled candidates:
    # - a fully observed mask goes to la;
    # - otherwise mask column 0 decides. Column 0 presumably belongs to the
    #   label-holding client (TODO confirm). If it is observed the sample is lu;
    #   if not, it is demoted to uu and `count` records how many were demoted.
    count = 0
    for idx in leftover_lu_idx:
        if torch.all(mask_array[idx]):
            la_indices.add(idx)
        else:
            if mask_array[idx][0]:
                lu_indices.add(idx)
            else:
                uu_indices.add(idx)
                count += 1

    # Unlabeled candidates:
    # - a fully observed mask goes to ua;
    # - otherwise the sample goes to uu, except that up to `count` samples with
    #   column 0 observed are promoted to lu, replacing the demoted ones above.
    for idx in leftover_u_idx:
        if torch.all(mask_array[idx]):
            ua_indices.add(idx)
        else:
            if count > 0:    
                if mask_array[idx][0]:
                    lu_indices.add(idx)
                    count -= 1
                    continue
            uu_indices.add(idx)

    # build final dataset with mask_array
    # (presumably pairs each base_ds item with its mask_array row — see FinalMissingDataset)
    final_train_ds = FinalMissingDataset(base_ds, mask_array)

    ds_la = Subset(final_train_ds, sorted(la_indices))
    ds_lu = Subset(final_train_ds, sorted(lu_indices))
    ds_ua = Subset(final_train_ds, sorted(ua_indices))
    ds_uu = Subset(final_train_ds, sorted(uu_indices))
    return ds_la, ds_lu, ds_ua, ds_uu

def create_train_datasets_with_cache(args):
    """Return the four training subsets, building and caching them on first use.

    The cache lives in ``./cached_data`` (relative to the current working
    directory) under ``build_cache_file_name(args, prefix="train")``. On a miss
    the subsets come from ``_create_train_in_memory`` and are written to disk.
    On a hit they are unpickled from that file.

    Returns:
        ``(ds_la, ds_lu, ds_ua, ds_uu)``.
    """
    import inspect

    cache_dir = "./cached_data"
    os.makedirs(cache_dir, exist_ok=True)
    fname = build_cache_file_name(args, prefix="train")
    path = os.path.join(cache_dir, fname)
    keys = ("ds_la", "ds_lu", "ds_ua", "ds_uu")

    if os.path.exists(path):
        print(f"[create_train_datasets_with_cache] load from {path}")
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            # The cache holds pickled Dataset/Subset objects, not plain tensors.
            # torch>=2.6 defaults to weights_only=True, which refuses to unpickle
            # them. The file is produced locally by this function, so full
            # unpickling is acceptable; never point this at untrusted files.
            load_kwargs["weights_only"] = False
        data = torch.load(path, **load_kwargs)
        return tuple(data[k] for k in keys)

    print(f"[create_train_datasets_with_cache] generating => {path}")
    splits = _create_train_in_memory(args)
    # Save to a temporary file, then atomically move it into place, so an
    # interrupted save never leaves a truncated cache that later runs would load.
    tmp_path = f"{path}.tmp"
    torch.save(dict(zip(keys, splits)), tmp_path)
    os.replace(tmp_path, path)
    return splits

        
def create_test_config(args, config_idx):
    """Build the fully / partially observed test subsets for one evaluation setting.

    ``config_idx`` selects the missing-data mechanism applied to the test set:

    * 0-3: MCAR with ``p_mcar`` in ``[0.0, 0.2, 0.5, 0.8]``;
    * 4-5: MAR with ``option = config_idx - 4``;
    * 6-7: MNAR with ``p_miss`` in ``[0.7, 0.9]``.

    For modelnet10 / hapt / isolet, normalisation statistics are read from the
    ``*_stats.pt`` files written while building the training set. That data
    must therefore have been generated first.

    Args:
        args: namespace; reads ``task_name`` and ``num_clients``.
        config_idx: integer in ``range(8)``.

    Returns:
        ``(ds_tla, ds_tlu)``: ``Subset`` views over the masked test set. The
        first holds samples whose mask is all True; the second holds the rest.

    Raises:
        ValueError: ``config_idx`` outside ``range(8)``. Negative values would
            otherwise silently index the probability lists from the end.
        KeyError: a task name not handled by any branch below.
    """
    if not 0 <= config_idx < 8:
        raise ValueError(f"config_idx must be in range(8), got {config_idx}")

    import inspect

    def _load_stats(path):
        # Stats files are written locally. Their contents come from the dataset
        # classes and may not be plain tensors. torch>=2.6 defaults to
        # weights_only=True, which can refuse non-tensor objects, so full
        # unpickling is requested when the option exists.
        kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            kwargs["weights_only"] = False
        return torch.load(path, **kwargs)

    dataset = args.task_name
    here = Path(__file__).absolute().parent
    if dataset == 'modelnet10':
        root_dir = here / 'dataset' / 'modelnet10_aligned'
        stats = _load_stats(f'{root_dir.parent}/modelnet10_stats.pt')
        transform_final = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(stats['mean'].tolist(), stats['std'].tolist()),
        ])
        base_test = ModelNet10MultiViewDataset(
            root_dir=str(root_dir),
            split='test',
            transform=transform_final,
            num_repeat=2,
        )
    elif dataset == 'hapt':
        root_dir = here / 'dataset' / 'hapt' / 'Test'
        stats = _load_stats(f'{root_dir.parent.parent}/hapt_stats.pt')
        base_test = HAPTDataset(
            X_file=f'{root_dir}/X_test.txt',
            y_file=f'{root_dir}/y_test.txt',
            num_clients=args.num_clients,
            # NOTE(review): the train mean/std are passed via `feature_indices`;
            # confirm this is the intended HAPTDataset parameter.
            feature_indices=[stats['mean'], stats['std']],
        )
    elif dataset == 'isolet':
        root_dir = here / 'dataset' / 'isolet'
        stats = _load_stats(f'{root_dir.parent}/isolet_stats.pt')
        base_test = IsoletDataset(
            data_file=f'{root_dir}/isolet5.data',
            num_clients=args.num_clients,
            feature_indices=[stats['mean'], stats['std']],
        )
    else:
        data_dir = here / 'dataset' / dataset
        transform = None
        if dataset == 'fashionmnist':
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))])
        datasets_dict = {'fashionmnist': datasets.FashionMNIST}
        DatasetClass = datasets_dict[dataset]
        base_test = DatasetClass(root=data_dir, train=False, download=True, transform=transform)

    if config_idx < 4:
        p_m = [0.0, 0.2, 0.5, 0.8][config_idx]
        full_test = MissingAppliedDataset_sample(base_test, args.task_name, args.num_clients, 'mcar', p_mcar=p_m)
    elif config_idx < 6:
        full_test = MissingAppliedDataset_sample(
            base_test, args.task_name, args.num_clients, 'mar',
            p_mcar=None, p_miss=None,
            option=config_idx - 4,
            init_cut_var=0.1 * (args.num_clients - 1),
        )
    else:
        p_m = [0.7, 0.9][config_idx - 6]
        full_test = MissingAppliedDataset_sample(base_test, args.task_name, args.num_clients, 'mnar', p_mcar=None, p_miss=p_m)

    log_pattern_counts_to_wandb(full_test, f'Test{config_idx}')

    # Split by whether every mask entry is observed.
    tla_idx, tlu_idx = [], []
    for i in range(len(full_test)):
        _, _, mask = full_test[i]
        if mask.all():
            tla_idx.append(i)
        else:
            tlu_idx.append(i)

    return Subset(full_test, tla_idx), Subset(full_test, tlu_idx)
        
def create_test_configs_with_cache(args):
    """Return all eight test configurations, building and caching them on first use.

    The cache lives in ``./cached_data`` (relative to the current working
    directory) under ``build_cache_file_name(args, prefix="test")``. That name
    depends only on task and client count, so other arguments (e.g. the seed)
    reuse the same cached test masks.

    Returns:
        ``(ds_tla_dict, ds_tlu_dict)``: dicts keyed by configuration index 0-7,
        as produced by ``create_test_config``.
    """
    import inspect

    cache_dir = "./cached_data"
    os.makedirs(cache_dir, exist_ok=True)
    fname = build_cache_file_name(args, prefix="test")
    path = os.path.join(cache_dir, fname)

    if os.path.exists(path):
        print(f"[create_test_configs_with_cache] load from {path}")
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            # The cache holds pickled Subset objects. torch>=2.6 defaults to
            # weights_only=True, which refuses them. The file is produced locally
            # by this function; never point this at untrusted files.
            load_kwargs["weights_only"] = False
        data = torch.load(path, **load_kwargs)
        return data['ds_tla_dict'], data['ds_tlu_dict']

    print(f"[create_test_configs_with_cache] generating => {path}")
    ds_tla_dict = {}
    ds_tlu_dict = {}
    for i in range(8):
        ds_tla_dict[i], ds_tlu_dict[i] = create_test_config(args, i)
    # Save to a temporary file, then atomically move it into place, so an
    # interrupted save never leaves a truncated cache behind.
    tmp_path = f"{path}.tmp"
    torch.save({'ds_tla_dict': ds_tla_dict, 'ds_tlu_dict': ds_tlu_dict}, tmp_path)
    os.replace(tmp_path, path)
    return ds_tla_dict, ds_tlu_dict

    
def print_exp_info(args, config, epoch):
    """Print a one-line experiment summary framed above and below by dashed rules.

    ``epoch`` is zero-based and is shown one-based. Each rule is exactly as
    wide as the summary line.
    """
    summary = (
        f'epoch:[{epoch + 1}/{config["num_epochs"]}]  {args.device}  {args.task_name}  '
        f'method:{args.method}  K:{args.num_clients}  '
        f'labeled_aligned_num:{args.labeled_aligned_num}   '
        f'labeled_unaligned_num:{args.labeled_unaligned_num}   '
        f'missing_type:{args.missing_type}  p_mcar:{args.p_mcar}  p_miss:{args.p_miss}  '
        f'option:{args.option}  seed:{args.seed}  '
        f'lr:{config["lr"]}  decay:{config["weight_decay"]}  mom:{config["momentum"]}'
    )
    rule = '-' * len(summary)
    print('\n'.join((rule, summary, rule)))

def init_wandb(args, config):
    """Start a wandb run for the current experiment.

    The run config takes dataset, model, lr, epochs, batch size, weight decay
    and momentum from the task ``config`` dict. It takes the CUDA id, method,
    seed and number of clients from ``args``. The run is named
    ``args.wandb_name`` when that is not ``None``; otherwise the name is built
    from task, method, client count, ``args.p_miss_train`` and seed.
    """
    run_config = dict(
        dataset=config["dataset"],
        architecture=config["model"],
        cuda=args.cuda_id,
        method=args.method,
        lr=config["lr"],
        n_epochs=config["num_epochs"],
        seed=args.seed,
        num_clients=args.num_clients,
        batch_size=config["batch_size"],
        weight_decay=config["weight_decay"],
        momentum=config["momentum"],
    )

    if args.wandb_name is not None:
        run_name = args.wandb_name
    else:
        # NOTE(review): `p_miss_train` is read only here, while print_exp_info
        # reports p_mcar / p_miss. Confirm the argument parser defines it.
        run_name = f'{args.task_name}_{args.method}_K{args.num_clients}_p_miss_train{args.p_miss_train}_s{args.seed}'

    wandb.init(project=args.project, config=run_config, name=run_name)

def set_seed(seed):
    """Seed NumPy, PyTorch and Python's ``random`` module with the same value."""
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)

@decorator
def time_decorator(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, print how long it took, and return its result.

    Applied through ``decorator.decorator`` (imported at the top of the file).
    Elapsed time is measured with ``time.perf_counter``, a monotonic
    high-resolution clock. ``time.time`` follows the wall clock and can jump
    (even go backwards) if the system time is adjusted during a long call.
    Nothing is printed if ``func`` raises.
    """
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed_time = time.perf_counter() - start
    print(f"time to run {func.__name__}(): {elapsed_time:.2f}s.")
    return result

@time_decorator
def test_ours(dataloader, models, criterion, args, is_final=False, is_train_data=False):
    """Evaluate the single model in ``models`` on ``dataloader`` without gradients.

    Samples whose mask row sums to zero are dropped before the forward pass,
    and batches that end up empty are skipped. The model is called as
    ``model(inputs, labels, masks.T, args, batch_size, training=False)`` and
    must return ``(batch_bound, batch_correct)``.

    Args:
        dataloader: yields batches ``(*inputs, labels, masks)`` with one mask
            row per sample.
        models: one-element sequence holding the model (set to eval mode).
        criterion: unused; kept for interface compatibility.
        args: must provide ``device`` and ``K_test``.
        is_final: prefix the metric keys with ``final_``.
        is_train_data: name the metrics ``train_*`` instead of ``test_*``.

    Returns:
        Dict with two entries, where ``n`` is the number of evaluated samples:

        * ``[final_]{train|test}_loss`` = ``-log(K_test) - sum(batch_bound) / n``;
        * ``[final_]{train|test}_acc`` = ``100 * correct / n``.

        Both are NaN when no sample in the loader has a non-zero mask, instead
        of failing on an empty ``torch.stack``.
    """
    [model] = models
    device = args.device
    model.eval()
    num_samples = len(dataloader.dataset)

    neg_bound_list = []
    correct = 0.0
    with torch.no_grad():
        for batch in dataloader:
            *inputs, labels, masks = batch
            # A sample is usable only if its mask row is not all zero. This is
            # vectorized instead of one .item() call per sample, and keeps
            # samples in their original batch order.
            valid = masks.reshape(masks.shape[0], -1).sum(dim=1) != 0
            valid_idx = valid.nonzero(as_tuple=True)[0].tolist()
            bs = len(valid_idx)
            num_samples -= masks.shape[0] - bs
            if bs == 0:
                continue

            new_inputs = [tensor[valid_idx].to(device) for tensor in inputs]
            new_labels, new_masks = labels[valid_idx].to(device), masks[valid_idx].to(device)

            batch_bound, batch_correct = model(new_inputs, new_labels, new_masks.T, args, bs, training=False)

            neg_bound_list.append(batch_bound.detach().cpu())
            correct += batch_correct

    data_split_type = "train" if is_train_data else "test"
    prefix = "final_" if is_final else ""

    if not neg_bound_list:
        # Nothing was evaluated: report NaN rather than dividing by zero.
        neg_bound_avg = float("nan")
        accuracy = float("nan")
    else:
        neg_bound_sum = torch.stack(neg_bound_list).sum().item()
        neg_bound_avg = -math.log(args.K_test) - (neg_bound_sum / num_samples)
        accuracy = 100 * correct / num_samples

    return {
        f"{prefix}{data_split_type}_loss": neg_bound_avg,
        f"{prefix}{data_split_type}_acc": accuracy,
    }

@time_decorator
def train_ours(dataloader, models, optimizers, criterion, args):
    """Run one training epoch of the single model / optimizer pair.

    Samples whose mask row sums to zero are dropped, and batches that end up
    empty are skipped. For each remaining batch the model is called as
    ``model(inputs, labels, masks.T, args, batch_size)``. Its returned
    ``batch_bound`` divided by the batch size is back-propagated, and gradients
    are clipped to norm 1.0 before the optimizer step. Progress is printed
    every 25 batches.

    Args:
        dataloader: yields batches ``(*inputs, labels, masks)`` with one mask
            row per sample.
        models: one-element sequence holding the model.
        optimizers: one-element sequence holding its optimizer.
        criterion: unused; kept for interface compatibility.
        args: must provide ``device`` and ``K``.

    Returns:
        ``{"train_loss": -log(K) - sum(batch_bound) / n, "train_acc": 0.0}``,
        where ``n`` is the number of usable samples. Accuracy is not computed
        here. ``train_loss`` is NaN when no sample had a non-zero mask, instead
        of failing on an empty ``torch.stack``.
    """
    [model], [optimizer] = models, optimizers
    device = args.device
    model.train()
    num_samples = len(dataloader.dataset)
    max_norm = 1.0  # gradient-clipping threshold

    neg_bound_list = []
    for batch_num, batch in enumerate(dataloader):
        *inputs, labels, masks = batch  # 'labels'=targets
        # A sample is usable only if its mask row is not all zero. This is
        # vectorized instead of one .item() call per sample, and keeps
        # samples in their original batch order.
        valid = masks.reshape(masks.shape[0], -1).sum(dim=1) != 0
        valid_idx = valid.nonzero(as_tuple=True)[0].tolist()
        bs = len(valid_idx)
        num_samples -= masks.shape[0] - bs
        if bs == 0:
            continue

        new_inputs = [tensor[valid_idx].to(device) for tensor in inputs]
        new_labels, new_masks = labels[valid_idx].to(device), masks[valid_idx].to(device)

        optimizer.zero_grad()

        batch_bound, _ = model(new_inputs, new_labels, new_masks.T, args, bs)

        # Back-propagate the batch average of the bound.
        loss = batch_bound / float(bs)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        neg_bound_list.append(batch_bound.detach().cpu())
        if (batch_num + 1) % 25 == 0:
            print(f"\tBatch [{batch_num + 1}/{len(dataloader)}] train loss (last): {loss.item():.4f}")

    if not neg_bound_list:
        # No usable sample this epoch: report NaN rather than dividing by zero.
        return {"train_loss": float("nan"), "train_acc": 0.0}

    neg_bound_sum = torch.stack(neg_bound_list).sum().item()
    neg_bound_avg = -math.log(args.K) - (neg_bound_sum / num_samples)
    return {"train_loss": neg_bound_avg, "train_acc": 0.0}


def setup_task(args):
    """Build everything needed to run ``args.method`` on ``args.task_name``.

    Reads ``configs/task_config.yaml`` (relative to the working directory) and
    selects ``configurations[method][task_name][num_clients][p_mcar]``. That
    config is then used to construct the model, optimizer, scheduler and
    criterion.

    Returns:
        ``(config, model, optimizer, scheduler, criterion, train, test)``,
        where ``train`` / ``test`` are the epoch functions for the method.

    Raises:
        ValueError: ``args.method`` has no train/test functions here. Only
            ``"ours"`` and ``"ours_mnar"`` are supported.
    """
    from models import get_model
    from optimizers import get_optimizer
    from criterions import get_criterion
    from schedulers import get_scheduler

    runners = {
        "ours": (train_ours, test_ours),
        "ours_mnar": (train_ours, test_ours),
    }
    if args.method not in runners:
        # Fail fast with a clear message before any model or optimizer is built.
        raise ValueError(
            f"setup_task: unsupported method {args.method!r}; expected one of {sorted(runners)}"
        )
    train, test = runners[args.method]

    with open("configs/task_config.yaml", "r", encoding="utf-8") as file:
        configurations = yaml.safe_load(file)

    config = configurations[args.method][args.task_name][args.num_clients][args.p_mcar]

    model = get_model(args.method, config["model"], config["dataset"], args, config)
    optimizer = get_optimizer(args.method, config["optimizer"], model, config)
    scheduler = get_scheduler(args.method, config["scheduler"], optimizer, config)
    criterion = get_criterion(config["criterion"])

    return config, model, optimizer, scheduler, criterion, train, test
