import datetime
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import wandb
import time
import copy
import random
import sys
import warnings
import h5py
import distill_utils
from tqdm import tqdm
from functools import partial


from prism_reparam_module import ReparamModule
from glob import glob
from torchvision.utils import save_image
from tqdm import tqdm, trange
from utils import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug, Conv3DNet, MultiStaticSharedDataset

# Suppress DeprecationWarning messages process-wide so they do not clutter the training output.
warnings.filterwarnings("ignore", category=DeprecationWarning)

def scaled_grad_hook(grad, scale):
    """Gradient hook that rescales an incoming gradient.

    Intended to be bound with ``functools.partial(scaled_grad_hook, scale=...)``
    and registered through ``Tensor.register_hook``.

    Args:
        grad: Gradient handed to the hook.
        scale: Multiplier applied to ``grad`` (ordinary ``*`` semantics, so
            broadcasting applies when both are tensors).

    Returns:
        The product ``grad * scale``.
    """
    rescaled = grad * scale
    return rescaled


def exposure_counts(idxes, T=16):
    """Count how much each stored key frame contributes to a T-frame clip.

    Every output frame ``0..T-1`` distributes a total weight of 1.0: a frame
    that is itself a key frame gives the full 1.0 to that key frame, while any
    other frame splits its weight between the nearest key frame on each side
    using linear-interpolation weights.

    Args:
        idxes: List of key-frame positions (must support ``in`` and ``.index``).
        T: Number of output frames to account for.

    Returns:
        1-D ``torch.Tensor`` with one accumulated weight per entry of ``idxes``.

    Raises:
        ValueError: If some frame in ``range(T)`` is not a key frame and has no
            key frame on one of its sides (``max``/``min`` of an empty sequence).
    """
    weights = [0.0 for _ in idxes]
    for frame in range(T):
        if frame in idxes:
            # The frame is stored explicitly: it is fully "exposed" to itself.
            weights[idxes.index(frame)] += 1.0
            continue
        # Nearest stored neighbours on each side of the missing frame.
        lower = max(anchor for anchor in idxes if anchor < frame)
        upper = min(anchor for anchor in idxes if anchor > frame)
        span = upper - lower
        # The closer neighbour receives the larger share.
        weights[idxes.index(lower)] += (upper - frame) / span
        weights[idxes.index(upper)] += (frame - lower) / span
    return torch.tensor(weights)             # shape  [K]


def compute_temporal_correlations(grad_tensor, idxes):
    """Cosine similarity between every frame's gradient and each pivot frame's gradient.

    Args:
        grad_tensor: Tensor whose first two dims are indexed as (batch, frame);
            all remaining dims are flattened into one feature vector per frame.
        idxes: Iterable of pivot frame indices.

    Returns:
        Dict mapping each pivot index to a ``(batch, frames)`` tensor holding
        the dot product of the L2-normalised frame gradients with the
        normalised gradient of that pivot frame.
    """
    flat = grad_tensor.flatten(2)  # Shape: (B, 16, -1)
    # The small constant keeps the division finite for all-zero gradients.
    unit = flat / (torch.norm(flat, dim=2, keepdim=True) + 1e-8)

    correlations = {}
    for pivot in idxes:
        pivot_column = unit[:, pivot, :].unsqueeze(2)
        correlations[pivot] = torch.bmm(unit, pivot_column).squeeze(-1)
    return correlations

def compute_temporal_correlations_l2(grad_tensor, idxes):
    """Euclidean distance between every frame's gradient and each pivot frame's gradient.

    Args:
        grad_tensor: Tensor whose first two dims are indexed as (batch, frame);
            all remaining dims are flattened into one feature vector per frame.
        idxes: Iterable of pivot frame indices.

    Returns:
        Dict mapping each pivot index to a ``(batch, frames)`` tensor of L2
        distances between each frame's flattened gradient and the pivot's.
    """
    flat = grad_tensor.flatten(2)  # Shape: (B, 16, -1)

    distances = {}
    for pivot in idxes:
        # Broadcast the pivot frame against all frames of the same sample.
        offset = flat - flat[:, pivot, :].unsqueeze(1)
        distances[pivot] = torch.norm(offset, dim=2)
    return distances

def analyze_gradient_differences(grad_tensor, idxes):
    """Compare per-frame gradients against the pivot frames using the configured metric.

    Reads the module-level ``args.l2_based_thresh`` flag: when set, L2
    distances are returned (``compute_temporal_correlations_l2``); otherwise
    cosine similarities (``compute_temporal_correlations``).

    Args:
        grad_tensor: Gradient tensor forwarded unchanged to the metric function.
        idxes: Pivot frame indices forwarded unchanged to the metric function.

    Returns:
        The dict produced by the selected metric function, keyed by pivot index.
    """
    if not args.l2_based_thresh:
        return compute_temporal_correlations(grad_tensor, idxes)
    return compute_temporal_correlations_l2(grad_tensor, idxes)


def interpolate_frames(input_tensor, idxes, target_frames=None):
    """Expand sparsely stored key frames into a dense clip by linear interpolation.

    Args:
        input_tensor: Key frames shaped ``(N, K, C, H, W)``; a 4-D
            ``(K, C, H, W)`` tensor is treated as a single sample.
        idxes: Temporal position of each of the ``K`` key frames in the output
            clip; ``idxes[j]`` is the output position of ``input_tensor[:, j]``.
        target_frames: Number of output frames. When ``None`` (the default) the
            length is derived from the module-level ``args.dataset``: 8 for
            ``'ssv2'``/``'k400'`` and 16 otherwise. An explicit value is
            honoured (it used to be silently overwritten).

    Returns:
        Tensor of shape ``(N, target_frames, C, H, W)``. Positions listed in
        ``idxes`` are copied from the input; every other position is a linear
        blend of the nearest key frame on each side.

    Raises:
        ValueError: If the input is not 4-D/5-D, or if a missing position has
            no key frame on one of its sides.
    """
    if input_tensor.dim() == 4:
        input_tensor = input_tensor.unsqueeze(0)
    if input_tensor.dim() != 5:
        raise ValueError(
            f"interpolate_frames expects a 4-D or 5-D tensor, got {input_tensor.dim()} dims"
        )
    if target_frames is None:
        # Fall back to the per-dataset clip length used throughout the script.
        target_frames = 8 if args.dataset in ('ssv2', 'k400') else 16

    # Output position -> slot of the stored key frame along dim 1.
    slot_of = {idx: pos for pos, idx in enumerate(idxes)}

    frames = []
    for t in range(target_frames):
        if t in slot_of:
            frames.append(input_tensor[:, slot_of[t]])
            continue
        left_idx = max(idx for idx in slot_of if idx < t)
        right_idx = min(idx for idx in slot_of if idx > t)
        gap = right_idx - left_idx
        # The nearer key frame gets the larger weight; the two weights sum to 1.
        right_weight = (t - left_idx) / gap
        left_weight = 1 - right_weight
        frames.append(
            left_weight * input_tensor[:, slot_of[left_idx]]
            + right_weight * input_tensor[:, slot_of[right_idx]]
        )
    return torch.stack(frames, dim=1)


def load_from_h5():
    """Load the training videos and labels for ``args.dataset`` from an HDF5 cache.

    Reads the module-level ``args.dataset`` to build file names under
    ``/local_data/video_h5_files``. If the train or test ``.h5`` file is
    missing, the ``.pt`` tensors under ``pt_files/`` are loaded and written to
    the train ``.h5`` file first. The train file is then read back fully into
    memory.

    Returns:
        Tuple ``(dst_train, label_all)`` where ``dst_train`` is a
        ``TensorDataset`` of (video, label) pairs and ``label_all`` is the
        ``torch.long`` label tensor.
    """
    print('Start loading data from saved!')
    h5_dir = '/local_data/video_h5_files'
    train_h5_path = f'{h5_dir}/{args.dataset}_data_train.h5'
    test_h5_path = f'{h5_dir}/{args.dataset}_data_test.h5'

    # NOTE(review): the test file is checked here but never written by this
    # function, so while it is absent the train file is rebuilt on every call.
    # Confirm whether another script is expected to produce it.
    if not os.path.exists(train_h5_path) or not os.path.exists(test_h5_path):
        print('Start loading data from saved!')
        video_all = torch.load(f'{h5_dir}/pt_files/{args.dataset}_video.pt')
        label_all = torch.load(f'{h5_dir}/pt_files/{args.dataset}_label.pt')

        # Create the directory the file is actually written to (the previous
        # relative 'video_h5_files' path created a stray folder in the CWD).
        os.makedirs(h5_dir, exist_ok=True)
        with h5py.File(train_h5_path, 'w') as f:
            f.create_dataset('video_all', data=video_all.cpu().numpy())
            f.create_dataset('label_all', data=label_all.cpu().numpy())
        print('Saved in h5 format')

    start = time.time()
    with h5py.File(train_h5_path, 'r') as f:
        # `[:]` materialises the whole dataset in memory before conversion.
        video_all = torch.tensor(f['video_all'][:])
        label_all = torch.tensor(f['label_all'][:], dtype=torch.long)
    print("Time to load data: ", time.time()-start)

    dst_train = torch.utils.data.TensorDataset(video_all, label_all)
    return dst_train, label_all


def main(args):
    """Run distillation of per-class synthetic videos with progressive key-frame insertion.

    Workflow, as implemented below:
      1. Load dataset metadata (``get_dataset``) and the training tensors
         (``load_from_h5``), start a wandb run, and rebuild ``args`` from the
         wandb config.
      2. For every class, create ``args.vpc`` synthetic videos stored as
         ``args.starting_frame_num`` randomly initialised key frames, each with
         its own SGD optimizer and a gradient hook that rescales per-key-frame
         gradients by inverse exposure count.
      3. Each iteration draws a freshly initialised network and, per class,
         minimises the squared distance between the mean embedding of a real
         batch and the mean embedding of the interpolated synthetic clip.
      4. Between the warm-up and cool-down phases (first/last 20% of the
         iterations unless disabled), new key frames are inserted wherever the
         gradient of an interpolated frame disagrees with both neighbouring key
         frames; the parameter tensor, optimizer, momentum buffer and gradient
         hook are rebuilt for that class.
      5. Checkpoints are written at the evaluation iterations; progress goes to
         wandb and ``log.txt`` in the run directory.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Raises:
        SystemExit: If the averaged loss exceeds 100.
    """
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    channel, im_size, num_classes, class_names, mean, std, _, _, _ = get_dataset(args.dataset, args.data_path)
    dst_train, label_all = load_from_h5()

    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
    project_name = "PRISM_{}".format(args.method)
    wandb.init(sync_tensorboard=False,
               project=project_name,
               job_type="CleanRepo",
               config=args,
               name = f'PRISM_{args.dataset}_ipc{args.vpc}_{args.lr}_{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}' #TODO Maybe change the name?
               )
    # Rebuild args as a plain attribute bag from the wandb config so that any
    # values overridden by wandb (e.g. sweeps) take effect below.
    args = type('', (), {})()

    for key in wandb.config._items:
        setattr(args, key, wandb.config._items[key])
    if args.batch_syn is None:
        args.batch_syn = num_classes *args.vpc 

    args.distributed = torch.cuda.device_count() > 1

    print('Hyper-parameters: \n', args.__dict__)
    print('Evaluation model pool: ', model_eval_pool)

    ''' organize the real dataset '''
    indices_class = [[] for c in range(num_classes)]

    print("BUILDING DATASET")
    for i, lab in tqdm(enumerate(label_all)):
        indices_class[lab].append(i)
    num_max_frames = 8 if args.dataset == 'ssv2' or args.dataset == 'k400' else 16

    def get_images(c, n):
        # Sample up to n random real videos of class c and move them to the device.
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        if n == 1:
            imgs = dst_train[idx_shuffle[0]][0].unsqueeze(0)
        else:
            imgs = torch.cat([dst_train[i][0].unsqueeze(0) for i in idx_shuffle], 0)
        return imgs.to(args.device)

    # Frames are only added between the warm-up and cool-down phases.
    warmup_phase_iteration = int((args.Iteration + 1) * 0.2)
    cooldown_phase_start_iteration = args.Iteration + 1 - warmup_phase_iteration
    if args.no_cooldown:
        cooldown_phase_start_iteration = args.Iteration + 1
    if args.no_warmup:
        warmup_phase_iteration = 0
    eval_it_pool = np.arange(warmup_phase_iteration, args.Iteration + 1, args.eval_it).tolist()
    eval_it_pool.append(args.Iteration)
    print('Evaluation iterations: ', eval_it_pool)

    start_it = 0
    if args.lr == int(args.lr):
        args.lr = int(args.lr)
    frame_wise_syn_dict = {}
    optimizer_frame_wise_dict = {}
    frame_wise_idxes_dict = {}
    grad_hook = {}
    frames_added_per_class = {}

    for c in range(num_classes):
        frame_wise_syn_dict[c] = torch.randn(size=(args.vpc, args.starting_frame_num, 3, im_size[0], im_size[1]), dtype=torch.float)
        frame_wise_syn_dict[c] = frame_wise_syn_dict[c].detach().to(args.device).requires_grad_(True)
        frame_wise_idxes_dict[c] = np.linspace(0, num_max_frames-1, args.starting_frame_num, dtype=int).tolist()
        with torch.no_grad():
            cnt = exposure_counts(frame_wise_idxes_dict[c], T=num_max_frames).to(args.device)
            # NOTE(review): unlike the rebuild after frame insertion below, this
            # initial scale is not clamped to [0.5, 1.5] — confirm intended.
            grad_scale = (cnt.mean() / (cnt + 1e-6)).view(1, -1, 1, 1, 1)
        grad_hook[c] = frame_wise_syn_dict[c].register_hook(partial(scaled_grad_hook, scale=grad_scale))
        optimizer_frame_wise_dict[c] = torch.optim.SGD([frame_wise_syn_dict[c]], lr=args.lr, momentum=0.95)

    print('%s training begins' % get_time())

    save_dir = os.path.join(args.save_path, project_name, wandb.run.name)
    os.makedirs(save_dir, exist_ok=True)
    _log = open(os.path.join(save_dir, 'log.txt'), 'w')

    # try/finally guarantees the log file is closed on normal completion, on
    # the early-exit path and on any unexpected exception.
    try:
        for it in tqdm(range(start_it, args.Iteration + 1), desc=f'{args.method} {args.dataset} {args.vpc} VPC {args.lr} LR', bar_format='{desc:<25} |{bar}| {percentage:3.0f}% [{n_fmt}/{total_fmt}] [{elapsed}<{remaining}, {rate_fmt}]'):
            wandb.log({"Progress": it}, step=it)
            if it in eval_it_pool or it == warmup_phase_iteration:
                # --- Every it, save the synthetic data (evaluation on a separate python file)
                with torch.no_grad():
                    save_dict = {
                        'frame_wise_syn_dict': frame_wise_syn_dict,
                        'optimizer_frame_wise_dict': optimizer_frame_wise_dict,
                        'frame_wise_idxes_dict': frame_wise_idxes_dict,
                        'grad_hook': grad_hook,
                        'frames_added_per_class': frames_added_per_class,
                        'it': it,
                    }

                    torch.save(save_dict, os.path.join(save_dir, f"checkpoint_{it}.pt"))

            net = get_network(args.model, channel, num_classes, im_size).to(args.device)  # get a random model
            net.train()
            for param in list(net.parameters()):
                param.requires_grad = False
            embed = net.module.embed if args.distributed else net.embed
            loss_avg = 0
            loss = torch.tensor(0.0).to(args.device)
            increased = False
            frames_added_per_class = {}

            for c in range(0, num_classes):
                if it == 0 or it == cooldown_phase_start_iteration:
                    print('\nWarmUp / CoolDown Phase\n')
                img_real = get_images(c, args.batch_real)
                img_syn = interpolate_frames(frame_wise_syn_dict[c], frame_wise_idxes_dict[c])

                output_real = embed(img_real).detach()
                output_syn = embed(img_syn)

                class_loss = torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0)) ** 2)

                wandb.log({f"Class {c} Loss": class_loss.item()}, step=it)
                # Accumulate a detached value: only used for logging, and this
                # avoids keeping every class's autograd graph referenced.
                loss += class_loss.detach()

                optimizer_frame_wise_dict[c].zero_grad()
                # img_syn is a fresh non-leaf tensor each iteration, so it has
                # no stale grad to clear; retain_grad keeps the dense per-frame
                # gradient that drives the frame-insertion decision below.
                img_syn.retain_grad()
                class_loss.backward()
                optimizer_frame_wise_dict[c].step()

                if it >= warmup_phase_iteration and it < cooldown_phase_start_iteration:
                    if it == warmup_phase_iteration:
                        print('\nAdding Frames\n')
                    if not(args.do_not_increase_frame) and len(frame_wise_idxes_dict[c]) < num_max_frames:
                        tmp_grad = img_syn.grad
                        results = analyze_gradient_differences(tmp_grad, frame_wise_idxes_dict[c])
                        added_idxes = []
                        for t in range(num_max_frames):
                            if t in frame_wise_idxes_dict[c]:
                                continue
                            lefts  = [p for p in frame_wise_idxes_dict[c] if p < t]
                            rights = [p for p in frame_wise_idxes_dict[c] if p > t]
                            if not lefts or not rights:
                                continue
                            left  = max(lefts)
                            right = min(rights)
                            # Only the first synthetic video of the class ([0, t]) drives the decision.
                            cos_left  = results[left][0, t]
                            cos_right = results[right][0, t]
                            if args.l2_based_thresh:
                                eps = 0.141
                                if cos_left > eps and cos_right > eps:
                                    added_idxes.append(t)
                            else:
                                eps = 0.0
                                if cos_left < -eps and cos_right < -eps:
                                    if args.random_addition:
                                        # Pick a random position strictly between the two
                                        # anchors. The previous bounds (left, right-1) could
                                        # return `left` itself (already a key frame) and
                                        # could never return right-1.
                                        added_idxes.append(random.randrange(left + 1, right))
                                    else:
                                        added_idxes.append(t)
                        if added_idxes:
                            old_set = set(frame_wise_idxes_dict[c])
                            new_set = sorted(list(old_set.union(added_idxes)))
                            if new_set != frame_wise_idxes_dict[c]:
                                frame_wise_idxes_dict[c] = new_set
                                increased = True
                                _temp, _temp_buffer = [], []
                                original_added = 0
                                for idx in frame_wise_idxes_dict[c]:
                                    if idx not in old_set: 
                                        # New key frame: initialise from noise or from the
                                        # current interpolation, with zero momentum.
                                        if args.add_frame_init == 'noise':
                                            _temp.append(torch.randn(size=(args.vpc, 3, im_size[0], im_size[1]), dtype=torch.float).to(args.device))
                                        elif args.add_frame_init == 'interpolated':
                                            _temp.append(img_syn[:, idx, :])
                                        _temp_buffer.append(torch.zeros_like(img_syn[:, idx, :]))
                                    else:
                                        # Existing key frame: carry over its values and a
                                        # damped copy of its momentum.
                                        _temp.append(frame_wise_syn_dict[c][:, original_added, :])
                                        for _, value in optimizer_frame_wise_dict[c].state.items():
                                            _temp_buffer.append(args.momentum_reduction*value['momentum_buffer'][:, original_added, :])
                                        original_added += 1

                                frame_wise_syn_dict[c] = torch.stack(_temp, dim=1).detach().to(args.device).requires_grad_(True)

                                buffer = torch.stack(_temp_buffer, dim=1).detach().to(args.device)
                                # The parameter tensor was replaced, so the hook and the
                                # optimizer must be rebuilt around the new tensor.
                                grad_hook[c].remove()
                                with torch.no_grad():
                                    cnt = exposure_counts(frame_wise_idxes_dict[c], T=num_max_frames).to(args.device)
                                    grad_scale = torch.clamp((cnt.mean() / (cnt + 1e-6)), 0.5, 1.5).view(1, -1, 1, 1, 1)
                                grad_hook[c] = frame_wise_syn_dict[c].register_hook(partial(scaled_grad_hook, scale=grad_scale))

                                optimizer_frame_wise_dict[c] = torch.optim.SGD([frame_wise_syn_dict[c]], lr=args.lr, momentum=0.95)
                                optimizer_frame_wise_dict[c].state[frame_wise_syn_dict[c]]['momentum_buffer'] = buffer

                                added_count = len(frame_wise_idxes_dict[c]) - len(old_set)
                                frames_added_per_class[c] = added_count 

            loss_avg += loss.item()

            loss_avg /= (num_classes)

            wandb.log({"Loss": loss_avg}, step=it)

            if  it % 100 == 0 or increased:
                _log.write(f"Adding Frames Iteration {it}/{args.Iteration} | Dataset: {args.dataset} | VPC: {args.vpc} | LR: {args.lr} | Loss: {loss_avg:.6f} \n")
                cnt_added_frames = 0
                for c in frames_added_per_class:
                    _log.write(f"[Adding] Class {c} | Added {frames_added_per_class[c]} frames. \n")
                    cnt_added_frames += frames_added_per_class[c]
                wandb.log({"Added": cnt_added_frames}, step=it)

            if loss_avg > 100:
                # Treat a diverged run as fatal; the finally block closes the log.
                _log.write(f"Loss is too high. Exiting...\n")
                sys.exit("Loss is too high. Exiting...")
    finally:
        _log.close()

    wandb.finish()

# Script entry point: parse the command-line options and launch the distillation run.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameter Processing')
    # NOTE(review): the default 'miniUCF101' is not one of the listed choices;
    # argparse does not validate defaults against `choices`, so running without
    # --dataset uses it as-is. Confirm which spelling the data files use.
    parser.add_argument('--dataset', type=str, default='miniUCF101', help='dataset', choices=['hmdb', 'miniucf', 'ssv2', 'k400'])

    parser.add_argument('--method', type=str, default='DM', help='MTT or DC or DM')
    parser.add_argument('--model', type=str, default='ConvNet3D', help='model')

    parser.add_argument('--vpc', type=int, default=1, help='number of synthetic videos per class', choices=[1, 5, 10])

    parser.add_argument('--eval_mode', type=str, default='SS',
                        help='eval_mode, check utils.py for more info')

    parser.add_argument('--num_eval', type=int, default=3, help='how many networks to evaluate on')

    parser.add_argument('--eval_it', type=int, default=500, help='how often to evaluate')

    parser.add_argument('--epoch_eval_train', type=int, default=500,
                        help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=5000, help='how many distillation steps to perform')

    parser.add_argument('--lr', type=float, default=0.01, help='learning rate for updating synthetic dynamic memory')

    parser.add_argument('--lr_lr', type=float, default=1e-05, help='learning rate for updating... learning rate')
    parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate')

    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_syn', type=int, default=None, help='should only use this if you run out of VRAM')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')

    parser.add_argument('--data_path', type=str, default='/local_data/', help='dataset path')
    parser.add_argument('--buffer_path', type=str, default='/local_data/buffer', help='buffer path')

    parser.add_argument('--expert_epochs', type=int, default=1, help='how many expert epochs the target params are')
    parser.add_argument('--syn_steps', type=int, default=64, help='how many steps to take on synthetic data')
    parser.add_argument('--max_start_epoch', type=int, default=10, help='max epoch we can start at')

    parser.add_argument('--num_workers',type=int,default=30,help='number of workers')
    parser.add_argument('--startIt',type=int,default=500,help='start iteration')
    parser.add_argument('--save_path',type=str, default='./result/', help='path to result')
    parser.add_argument('--do_not_increase_frame', action='store_true', help='keep the initial key frames; never add frames during distillation')
    parser.add_argument('--starting_frame_num', type=int, default=2)
    parser.add_argument('--init', type=str, default='noise', choices=['noise', 'real'], help='initialization method')
    parser.add_argument('--add_frame_init', type=str, default='interpolated', choices=['noise', 'interpolated'], help='initialization method for added frames')
    parser.add_argument('--momentum_reduction', type=float, default=0.3, help='momentum reduction factor')
    parser.add_argument('--train_lr', action='store_true', help='train lr')
    parser.add_argument('--random_addition', action='store_true', help='random addition of frames')
    parser.add_argument('--l2_based_thresh', action='store_true', help='decide frame addition from L2 gradient distance instead of cosine similarity')
    parser.add_argument('--no_warmup', action='store_true', help='no warmup')
    parser.add_argument('--no_cooldown', action='store_true', help='no cooldown')

    args = parser.parse_args()

    main(args)
