import torch 
import argparse
import numpy as np  
from utils  import get_network, set_dataset_specs, get_normalize_trans, eval,eval_mp , ImageFolderSubsampleTwins, cutmix
from torchvision import transforms      
from tqdm import tqdm       
import torchvision
import os 
import matplotlib.pyplot as plt 
import time 
from torch.optim.lr_scheduler import LambdaLR
import math 
import random
import multiprocessing  
from diffusers import AutoPipelineForImage2Image, StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import torch.multiprocessing as mp 
from accelerate import PartialState 
import torch.distributed as dist

import warnings  
# NOTE(review): this silences every Python warning for the whole process,
# including deprecation warnings from torch/diffusers/accelerate.
warnings.simplefilter("ignore")


# Default device used when moving batches onto the accelerator in the training loop.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


def generate_slerp_interpolations(x, x_twin, use_slerp=False, eps=1e-6):
    """Randomly mix each sample of ``x`` with the matching sample of ``x_twin``.

    One weight ``alpha ~ U(0, 1)`` is drawn per sample with a single
    ``torch.rand(batch_size)`` call.

    By default (``use_slerp=False``) the result is the linear blend
    ``alpha * x + (1 - alpha) * x_twin``. Despite the function's name, this is
    what it has always returned. With ``use_slerp=True``, each sample is
    flattened and interpolated along the great circle between the two vectors,
    using the same convention: weight ``sin(alpha*theta)/sin(theta)`` on ``x``
    and ``sin((1-alpha)*theta)/sin(theta)`` on ``x_twin``. That path reduces to
    the linear blend as ``theta -> 0``, and pairs whose angle has
    ``|sin(theta)| <= eps`` (parallel or antiparallel) use the linear blend
    directly.

    Args:
        x: tensor of shape ``(B, ...)``.
        x_twin: tensor with the same shape as ``x``.
        use_slerp: use spherical instead of linear interpolation.
        eps: guard for zero-norm vectors and near-zero ``sin(theta)`` in the
            slerp path.

    Returns:
        Tensor with the shape of ``x``. Its dtype follows torch type promotion
        of ``x`` with the float32 weights, so float16 inputs come back as
        float32 (same as the original linear blend).
    """
    batch_size = x.shape[0]
    # Broadcastable per-sample weights, e.g. (B, 1, 1, 1) for 4-D inputs.
    alphas = torch.rand(batch_size, device=x.device).view(batch_size, *([1] * (x.dim() - 1)))

    if not use_slerp:
        return alphas * x + (1 - alphas) * x_twin

    # Work in at least float32: fp16 latents make acos/sin noisy.
    work_dtype = torch.promote_types(torch.promote_types(x.dtype, x_twin.dtype), alphas.dtype)
    v_x = x.reshape(batch_size, -1).to(work_dtype)
    v_t = x_twin.reshape(batch_size, -1).to(work_dtype)
    a = alphas.reshape(batch_size, 1).to(work_dtype)

    unit_x = v_x / v_x.norm(dim=1, keepdim=True).clamp_min(eps)
    unit_t = v_t / v_t.norm(dim=1, keepdim=True).clamp_min(eps)
    # Rounding can push |dot| slightly past 1, where acos returns NaN.
    dot = (unit_x * unit_t).sum(dim=1, keepdim=True).clamp(-1.0, 1.0)
    theta = torch.acos(dot)
    sin_theta = torch.sin(theta)

    use_arc = sin_theta.abs() > eps
    safe_sin = torch.where(use_arc, sin_theta, torch.ones_like(sin_theta))
    arc = (torch.sin(a * theta) * v_x + torch.sin((1 - a) * theta) * v_t) / safe_sin
    linear = a * v_x + (1 - a) * v_t
    mixed = torch.where(use_arc, arc, linear)
    return mixed.reshape(x.shape)

def _resolve_syn_path(args):
    """Return the directory that holds the synthetic training images.

    Starting from ``args.syn_data_path`` it descends into ``lvl_<args.lvl>``
    (when ``args.lvl`` is set), then ``train`` (when such an entry exists),
    then ``ipc_<args.ipc>`` (when any entry starts with ``ipc_``).
    """
    syn_path = args.syn_data_path
    if args.lvl is not None:
        syn_path = os.path.join(syn_path, f'lvl_{args.lvl}')
    if 'train' in os.listdir(syn_path):
        syn_path = os.path.join(syn_path, 'train')
    # Per-IPC folders are named 'ipc_<N>', so test the prefix. A membership
    # test for the literal string 'ipc_' never matches them.
    if any(entry.startswith('ipc_') for entry in os.listdir(syn_path)):
        syn_path = os.path.join(syn_path, f'ipc_{args.ipc}')
    return syn_path


def _gather_across_processes(tensor, distributed_state):
    """Concatenate ``tensor`` from every process along dim 0, in process order.

    In a single-process run the tensor is returned unchanged, because no
    process group may exist there. With several processes, per-process lengths
    are exchanged first so that uneven splits (e.g. a final batch that does not
    divide evenly) are padded for ``all_gather`` and trimmed afterwards.
    """
    world_size = distributed_state.num_processes
    if world_size == 1:
        return tensor
    tensor = tensor.contiguous()
    local_len = torch.tensor([tensor.shape[0]], device=tensor.device)
    all_lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(all_lens, local_len)
    all_lens = [int(n.item()) for n in all_lens]
    max_len = max(all_lens)
    if tensor.shape[0] < max_len:
        pad = tensor.new_zeros((max_len - tensor.shape[0], *tensor.shape[1:]))
        tensor = torch.cat([tensor, pad], dim=0)
    chunks = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(chunks, tensor)
    return torch.cat([chunk[:n] for chunk, n in zip(chunks, all_lens)], dim=0)


def _diffusion_augment(pipe, images, images_twin, prompts, latent_timestep, strength, gen_device):
    """Blend two image batches in latent space and re-render the blend with ``pipe``."""
    # Latents are produced by pipe.prepare_latents without added noise.
    lat = pipe.image_processor.preprocess(images).to(gen_device)
    lat = pipe.prepare_latents(lat, latent_timestep, 1, 1, generator=None, add_noise=False, device=gen_device, dtype=torch.float16)
    lat_twin = pipe.image_processor.preprocess(images_twin).to(gen_device)
    lat_twin = pipe.prepare_latents(lat_twin, latent_timestep, 1, 1, generator=None, add_noise=False, device=gen_device, dtype=torch.float16)
    mixed = generate_slerp_interpolations(lat, lat_twin)
    return pipe(prompt=prompts, image=mixed, num_inference_steps=5, strength=strength, guidance_scale=0.,
                num_images_per_prompt=1, output_type='pt').images


def _save_preview_images(images, size):
    """Write each image to ``gen_<i>.png`` in the cwd, resized to ``size`` x ``size``.

    Assumes ``images`` are CHW float tensors in [0, 1] (pipe output_type='pt'),
    which is not checked here.
    """
    from PIL import Image
    for i, img in enumerate(images):
        Image.fromarray(np.uint8(img.cpu().numpy().transpose(1, 2, 0) * 255.)).resize((size, size)).save(f'gen_{i}.png')


def train_soft_lbls(args):
    """Train ``args.arch`` on teacher soft labels over diffusion-augmented synthetic data.

    Each batch of (image, twin image, label) triples from the synthetic set is
    handled in four steps:

    1. The batch is upsampled and split across processes.
    2. With probability ``1 - args.aug_drop_p``, each shard is replaced by
       SDXL-Turbo img2img renders of a random latent blend of image and twin.
    3. The shards are gathered back on every process.
    4. The student is trained with KL divergence against the teacher's
       temperature-scaled softmax.

    The main process evaluates after every epoch and saves the best weights to
    the current directory.

    Returns:
        Best test accuracy seen by the main process (``-1`` on other processes).
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    random.seed(args.seed)

    if args.teacher_arch is None:
        args.teacher_arch = args.arch

    model = get_network(args, args.arch, pretrained=False, data_parallel=True)

    if args.resume_training:
        model.load_state_dict(torch.load(args.model_ckpt_student, map_location='cpu'))
        print(f"Model loaded from {args.model_ckpt_student}")

    model_teacher = get_network(args, args.teacher_arch, pretrained=True, data_parallel=True)
    model_teacher.eval()

    normalize, _ = get_normalize_trans(args)

    # Applied after generation/gather, on the GPU batch.
    train_trans_later = transforms.Compose([transforms.RandomResizedCrop(args.input_size, scale=(args.rnd_res_scale_st, 1)), transforms.RandomHorizontalFlip(), normalize])
    print(f'training trans size: {args.input_size}, rnd_res_scale_st: {args.rnd_res_scale_st}')

    train_trans = transforms.Compose([transforms.ToTensor()])

    if args.subset.startswith('imagenet'):
        tst_trans = transforms.Compose([transforms.Resize(args.init_resize), transforms.CenterCrop(args.input_size), transforms.ToTensor(), normalize])
    else:
        tst_trans = transforms.Compose([transforms.ToTensor(), normalize])

    print(f'test trans initial resize: {args.init_resize}, input size: {args.input_size}')

    syn_path = _resolve_syn_path(args)

    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")

    distributed_state = PartialState()
    pipe.to(distributed_state.device)
    gen_device = distributed_state.device
    # PartialState works with and without an initialized process group, unlike dist.get_rank().
    is_main = distributed_state.is_main_process

    ds_train = ImageFolderSubsampleTwins(syn_path, samples_per_class=-1, transform=train_trans)
    print('sub samples were loaded with sample per class:', args.ipc)
    print(f'training images are loaded from {syn_path}')

    if args.subset == 'cifar100':
        ds_tst = torchvision.datasets.CIFAR100(root=args.val_dir, train=False, download=True, transform=tst_trans)
    elif args.subset == 'cifar10':
        ds_tst = torchvision.datasets.CIFAR10(root=args.val_dir, train=False, download=True, transform=tst_trans)
    else:
        ds_tst = torchvision.datasets.ImageFolder(os.path.join(args.val_dir, 'val'), transform=tst_trans)

    print(len(ds_train), len(ds_tst))

    train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    tst_dl = torch.utils.data.DataLoader(ds_tst, batch_size=128, shuffle=False, num_workers=args.workers)

    if args.init_acc and is_main:
        acc = eval(model_teacher, tst_dl)
        print(f"Initial Accuracy: {np.round(acc, 2)}")

    if args.init_acc_st and is_main:
        acc = eval(model, tst_dl)
        print(f"Initial Accuracy: {np.round(acc, 2)}")

    criterion = torch.nn.KLDivLoss(reduction="batchmean")
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd, betas=(0.9, 0.999))
    # NOTE(review): cos(pi * step / (2 * epochs)) only decays the LR factor from 1 to 0.5; confirm intended.
    sch = LambdaLR(optimizer, lambda step: 0.5 * (1.0 + math.cos(math.pi * step / args.epochs / 2)) if step <= args.epochs else 0, last_epoch=-1)

    best_acc = -1
    best_state_dict_cpu = None

    # Labels index this sorted folder listing of the same directory the dataset reads.
    # Assumes ImageFolderSubsampleTwins assigns labels in sorted folder-name order; TODO confirm.
    class_folder_names = sorted(f for f in os.listdir(syn_path) if not f.startswith('.'))
    dataset_dict = torch.load(args.dataset_name_dict)
    class_names = [dataset_dict[folder_name].replace('_', ' ') for folder_name in class_folder_names]

    print('.....................Training with soft labels.....................')

    resize_up = transforms.Resize(320)

    timesteps, num_inference_steps = pipe.get_timesteps(
        5,
        args.strength,
        gen_device,
        denoising_start=None,
    )
    latent_timestep = timesteps[:1].repeat(1 * 1)

    for epoch in range(args.epochs):
        model.train()
        t1 = time.time()
        preview_saved = False
        for x, x_twin, y in tqdm(train_dl):
            # Build prompts from the CPU labels to avoid a GPU sync per element.
            prompts = [f'a photo of a {class_names[int(label)]}' for label in y.tolist()]
            x_lq = resize_up(x.to(device))
            x_twin = resize_up(x_twin.to(device))
            y = y.to(device)

            in_dict = {'images': x_lq, 'prompts': prompts, 'labels': y, 'images_twin': x_twin}
            with distributed_state.split_between_processes(in_dict) as dic:
                with torch.no_grad():
                    drop_val = np.random.uniform(0., 1, size=1)
                    if drop_val > args.aug_drop_p:
                        x_gen = _diffusion_augment(pipe, dic['images'], dic['images_twin'], dic['prompts'],
                                                   latent_timestep, args.strength, gen_device)
                        if not preview_saved and is_main:
                            preview_saved = True
                            _save_preview_images(x_gen, args.input_size)
                    else:
                        x_gen = dic['images']

            x = _gather_across_processes(x_gen, distributed_state)
            x = train_trans_later(x)

            if args.cutmix:
                x = cutmix(x)

            out_soft = torch.log_softmax(model(x) / args.temp, dim=1)

            with torch.no_grad():
                soft_lbls = torch.softmax(model_teacher(x) / args.temp, dim=1)

            loss = criterion(out_soft, soft_lbls)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        sch.step()

        if is_main:
            acc_tst = eval(model, tst_dl)

            if acc_tst > best_acc:
                best_acc = acc_tst
                # copy=True so a CPU-resident model's snapshot does not alias live parameters.
                best_state_dict_cpu = {key: value.detach().to('cpu', copy=True) for key, value in model.state_dict().items()}
                checkpoint_save_name = f"model_{args.extra_desc}_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}_seed_{args.seed}.pth"
                torch.save(best_state_dict_cpu, checkpoint_save_name)

            print(f"Epoch: {epoch}, tst_acc: {np.round(acc_tst, 2)}, best_tst_acc: {np.round(best_acc, 2)}")

            elapsed_time = time.time() - t1
            print(f"Elapsed time: {elapsed_time:.2f} seconds")

    # Re-save the best weights with the accuracy in the file name (nothing to save if no epoch ran).
    if is_main and best_state_dict_cpu is not None:
        checkpoint_save_name = f"model_{args.extra_desc}_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}_acc_{np.round(best_acc, 2)}_seed_{args.seed}.pth"
        torch.save(best_state_dict_cpu, checkpoint_save_name)

    print(f'final best acc: {best_acc}')

    return best_acc



if __name__ == "__main__":

    # Command-line interface. Every option is a flag; registration order only affects --help.
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--cls_ind", type=int, default=None)

    # Dataset / subset selection.
    add("--subset", type=str, default="imagenette")
    add("--size", type=int, default=224)
    add("--st_cls", type=int, default=0)
    add("--end_cls", type=int, default=None)
    add("--eval_ratio", type=float, default=0.)

    # Naming and checkpoints.
    add("--extra_desc", type=str, default="")
    add("--model_ckpt", type=str, default=None)
    add("--model_ckpt_student", type=str, default=None)

    add("--nclass", type=int, default=1000)
    # NOTE(review): type=list turns a single string into a list of characters.
    add("--classes", type=list)
    add("--init_resize", type=int, default=256)
    add("--input_size", type=int, default=None)
    add("--diff_input_size", type=int, default=None)

    # Data locations.
    add("--root_dir", type=str, default=None)
    add("--val_dir", type=str, default=None)
    add("--root_aug", type=str, default=None)

    add("--emb_root_dir", type=str, default=None)
    add("--emb_file_name", type=str, default=None)

    # Student / teacher architectures.
    add("--arch", type=str, default='resnet18')
    add("--teacher_arch", type=str, default=None)

    # Optimization.
    add("--batch_size", type=int, default=256)
    add("--epochs", type=int, default=200)
    add("--lr", type=float, default=0.1)
    add("--wd", type=float, default=1e-4)
    add("--mom", type=float, default=.9)
    add("--init_acc", action='store_true', default=False)
    add("--init_acc_st", action='store_true', default=False)

    add("--force_rded_net", action='store_true', default=False)
    add("--resume_training", action='store_true', default=False)

    add("--use_json_ds", action='store_true', default=False)
    add("--json_file_path", type=str, default=None)

    # Probability of skipping the diffusion augmentation for a batch.
    add("--aug_drop_p", type=float, default=0.)

    add("--connection_string", type=str, default=None)
    add("--container_name", type=str, default=None)

    add("--soft_lbl", action='store_true', default=False)

    add("--workers", type=int, default=8)

    add("--aug", action='store_true', default=False)
    add("--full_res", action='store_true', default=False)
    add("--cutmix", action='store_true', default=False)

    add("--dataset_name_dict", type=str)
    add("--temp", type=float, default=1)
    add("--syn_data_path", type=str, default=None)

    add("--strength", type=float, default=0.7)
    add("--rnd_res_scale_st", type=float, default=0.8)

    add("--lvl", type=int, default=None)
    add("--seed", type=int, default=0)
    add("--exp_num", type=int, default=1)

    add("--ipc", type=int, default=None)

    args = parser.parse_args()

    # Echo the full configuration, one "key: value" per line.
    print("\n".join(f"{key}: {value}" for key, value in vars(args).items()))

    set_dataset_specs(args)

    # Run exp_num experiments with consecutive seeds starting at --seed.
    base_seed = args.seed
    accuracies = []
    for offset in range(args.exp_num):
        args.seed = base_seed + offset
        print(f'running experiment with seed {args.seed}')
        accuracies.append(train_soft_lbls(args))

    mean_acc = np.round(np.mean(accuracies), 2)
    std_acc = np.round(np.std(accuracies), 2)
    print(f"Average acc: {mean_acc} +- {std_acc}")

