import os
import math
from decimal import Decimal
# TensorFlow (pulled in indirectly by TensorBoard when installed) reads this
# flag from the process environment, so a plain Python variable has no effect.
# The module-level name is kept for anything that may import it.
TF_ENABLE_ONEDNN_OPTS = 0
os.environ['TF_ENABLE_ONEDNN_OPTS'] = str(TF_ENABLE_ONEDNN_OPTS)

import utility
from utility import *
import numpy as np
import torch
import torch.nn.utils as utils
from tqdm import tqdm
import shade
from shade import model_dict_to_vector, model_vector_to_dict
from torch.utils.data import dataloader
from torch.utils.tensorboard import SummaryWriter
from functools import partial
from typing import Callable
from plot_utils import plot_loss, plot_paras


import data
import model as _model
import loss as _loss
from option import args
import copy

import logging
import random
import wandb
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from data_util import create_loader
from timm.utils import *
import torch.distributed as dist
# Global cuDNN switches, applied at import time.
# NOTE(review): `benchmark` and `deterministic` are both enabled here; confirm
# against the PyTorch reproducibility notes that this combination is intended.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
# Module-level logger used by main() for rank / DE-iteration messages.
_logger = logging.getLogger('de_train')

# Root logging goes to the file named by args.log_dir, opened in append mode.
logging.basicConfig(level=logging.DEBUG, filename=args.log_dir, filemode='a')

# Initial seed; main() re-seeds with args.seed + args.rank once the rank is known.
torch.manual_seed(args.seed)
def main():
    """Run the differential-evolution (DE) loop over a population of weights.

    Steps performed, in order:
      1. Build the checkpoint helper, model and (unless ``args.test_only``)
         loss from the global ``args``.
      2. Configure distributed execution from ``WORLD_SIZE`` and
         ``args.local_rank``, then seed torch / numpy with
         ``args.seed + args.rank``.
      3. Load the initial population and record the SSIM / PSNR of every
         member on the test loader.
      4. For each epoch, score the population on a freshly created DE loader
         and call ``shade.evolve`` ``args.de_iters`` times. Local rank 0 logs
         the results and writes the population to ``./.cache2/pop_tensor.pt``;
         every rank then reloads the population from that file and
         re-validates the members flagged in ``update_label``.

    NOTE(review): the model is always wrapped in DistributedDataParallel and
    ``dist.barrier()`` is called unconditionally, while the process group is
    only initialised when ``args.distributed`` is true. This function therefore
    appears to require a multi-process launch - confirm before running it
    single-process.
    """
    print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Setting Setting')
    print(args)
    print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Setting Setting')
    checkpoint = utility.checkpoint(args)
    ckp = checkpoint

    model = _model.Model(args, checkpoint)
    # `loss` is not used below; the construction is kept in case it has side effects.
    loss = _loss.Loss(args, checkpoint) if not args.test_only else None
    writer = SummaryWriter("logs")
    setup_default_logging()

    # ---- distributed setup -------------------------------------------------
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        print('distributed:', args.distributed)
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

    # Per-rank seeding so every process draws a different random stream.
    torch.manual_seed(args.seed + args.rank)
    np.random.seed(args.seed + args.rank)
    torch.cuda.manual_seed_all(args.seed + args.rank)

    # ---- initial population and baseline metrics ---------------------------
    # The population is loaded with the model on CPU, then the model is moved
    # to the GPU and wrapped for distributed execution.
    population = load_populaton(args.pop_init_dir, model.to('cpu'), args.popsize)

    args.sync_bn = True
    model.cuda()
    model = NativeDDP(model, device_ids=[args.local_rank])

    output_dir = args.output_dir
    eval_metrics = OrderedDict([('ssim', []), ('psnr', [])])
    loader_test = data.Data(args).loader_test
    for i in range(args.popsize):
        solution = population[i]
        model_weights_dict = model_vector_to_dict(model=model, weights_vector=solution)
        model.load_state_dict(model_weights_dict)
        eval_metrics_temp = validate(0, model, loader_test, args, ckp)
        eval_metrics['ssim'].append(eval_metrics_temp['ssim'])
        eval_metrics['psnr'].append(eval_metrics_temp['psnr'])
    # The length is formatted into the message: a second positional argument
    # would be taken as write_log's `refresh` flag rather than being logged.
    ckp.write_log('loader test length: {}'.format(len(loader_test)))
    de_dataset = data.div2k.DIV2K(args, train=True, name='DIV2K', de=True)
    de_dataset.set_scale(0)

    loader_de_args = dict(
        batch_size=args.mini_batch_size,
        is_training=True,
        num_workers=args.n_threads,
        use_prefetcher=False,
        pin_memory=not args.cpu,
        persistent_workers=False,
        worker_seeding='all',
        )

    # ---- DE / SHADE state, initialised once and carried across epochs ------
    popsize = args.popsize
    max_iters = args.de_iters
    memory_size, bounds, lp, cr_init, f_init, k_ls = 5, 1, 5, 0.2, 0.1, [0,0,0,0]
    dim = len(model_dict_to_vector(model))
    # Memory of control settings, one column per slot (4 slots).
    u_f = np.ones((memory_size,4)) * f_init
    u_cr = np.ones((memory_size,4)) * cr_init
    u_freq = np.ones((memory_size,4)) * 0.5
    ns_1, nf_1, ns_2, nf_2, dyn_list_nsf = [], [], [], [], []
    p1_c, p2_c, p3_c, p4_c = 0.25, 0.25, 0.25, 0.25
    paras1 = [lp, cr_init, f_init, bounds, dim, popsize, max_iters]

    paras2 = [p1_c, p2_c, p3_c, p4_c, ns_1, nf_1, ns_2, nf_2, u_freq, u_f, u_cr, k_ls, dyn_list_nsf]
    # History buffers for plotting; `plot_variables` holds the same 12 lists
    # that the individual names are bound to.
    plot_variables = [u_freq_mean, u_f_mean, u_cr_mean, epoch_num, epoch_num_2, \
                        cons_sim_list, l2_dist_list, lowest_dist_list, mean_dist_list, largest_dist_list, \
                        L1_value_list, L2_value_list] = [[] for i in range(12)]

    for epoch in range(args.de_epochs):
        ckp.write_log('[Epoch {}]\t'.format(epoch))
        timer_de = utility.timer()
        timer_de.tic()

        # A new loader per epoch with an epoch-dependent worker seed.
        loader_de_args['worker_seeding'] = epoch + 233
        loader_de = create_loader(de_dataset, **loader_de_args)
        if args.distributed:
            loader_de.sampler.set_epoch(epoch)
        model.eval()

        dist.barrier()
        rand_lambda = args.rand_lambda
        score_lst = score_func(model, population, loader_de, args, ckp,rand_lambda)
        print('score_lst..............', score_lst)
        if args.local_rank == 0:
            # Lower score is better: the best member is the minimum.
            bestidx = score_lst.index(min(score_lst))
            worstidx = score_lst.index(max(score_lst))
            de_iter_loss = [round(-j.item(), 4) for j in score_lst]
            _logger.info('de_iter:{}, best_score:{:>7.4f}, best_idx:{}, de_iter_loss: {}'.format(
                                                             0, min(score_lst), bestidx, de_iter_loss))
            de_iter_dict = OrderedDict([('iter', 0), ('bestidx', bestidx), ('train_loss', de_iter_loss)])
            update_summary(epoch, de_iter_dict, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=True)
        dist.barrier()

        de_iter_time_m = AverageMeter()
        end = time.time()
        rand_lambda = args.rand_lambda
        for de_iter in range(1, args.de_iters+1):
            evolve_out = shade.evolve(score_func, epoch, de_iter, population, score_lst, paras1, paras2, \
                                            model, loader_de, args, ckp, rand_lambda)
            if args.local_rank == 0:
                population, score_lst, bestidx, worstidx, dist_matrics, paras2, update_label = evolve_out
                p1_c, p2_c, p3_c, p4_c, ns_1, nf_1, ns_2, nf_2, u_freq, u_f, u_cr, k_ls, dyn_list_nsf = paras2
                cons_sim, l2_dist, lowest_dist, mean_dist, largest_dist, L1_value, L2_value = dist_matrics
                best_score = score_lst[bestidx]
                de_iter_loss = [round(-j.item(), 4) for j in score_lst]
                iter_msg = 'de_iter:{}, best_score:{:>7.4f}, best_idx:{}, de_iter_loss: {}'.format(
                    de_iter, best_score, bestidx, de_iter_loss)
                _logger.info(iter_msg)
                # Log the current iteration and its best score (previously a
                # hard-coded iteration 0 was written here).
                ckp.write_log(iter_msg)
                pop_tensor = torch.stack(population)
                if epoch%2 == 0 and de_iter == 1:
                    # Snapshot of the population on every second epoch.
                    pop_save = OrderedDict([('epoch', epoch), ('de_iter', de_iter), ('pop', pop_tensor)])
                    torch.save(pop_save, os.path.join(output_dir, 'pop_save'+'_'+str(epoch)+'_'+str(de_iter)+'_'+str(args.local_rank)+'.pt'))
                # The evolved population is handed to the other ranks via a file.
                os.makedirs('./.cache2/', exist_ok=True)
                torch.save(pop_tensor,'./.cache2/pop_tensor.pt')
            else:
                # Placeholder of the right length; overwritten by the broadcast below.
                update_label = list(range(popsize))

            dist.barrier()

            torch.distributed.broadcast_object_list(update_label, src=0)
            pop_tensor = torch.load('./.cache2/pop_tensor.pt', map_location='cpu').cuda()
            dist.barrier()
            population = list(pop_tensor)

            # Re-validate only the members flagged as updated.
            for i in range(popsize):
                if update_label[i] == 1:
                    solution = population[i]
                    model_weights_dict = model_vector_to_dict(model=model, weights_vector=solution)
                    model.load_state_dict(model_weights_dict)
                    eval_metrics_temp = validate(epoch, model, loader_test, args, ckp)
                    eval_metrics['ssim'][i] = round(eval_metrics_temp['ssim'], 4)
                    eval_metrics['psnr'][i] = round(eval_metrics_temp['psnr'], 4)

            ckp.write_log('eval_metrics_ssim: {}'.format(eval_metrics['ssim']))
            ckp.write_log('eval_metrics_psnr: {}'.format(eval_metrics['psnr']))
            torch.cuda.synchronize()
            de_iter_time_m.update(time.time() - end)
            end = time.time()

            if args.local_rank == 0:
                ckp.write_log(
                     'DE: {} [de_iter: {}]  '
                     'SSIM: {ssim:>7.4f}  '
                     'PSNR: {psnr:>7.4f}  '
                     'Iter_time: {de_iter_time.val:.3f}s, {rate:>7.2f}/s  '.format(
                         epoch, de_iter,
                         ssim = eval_metrics['ssim'][bestidx],
                         psnr = eval_metrics['psnr'][bestidx],
                         de_iter_time=de_iter_time_m,
                         rate= args.de_batch_size / de_iter_time_m.val))

                train_metrics = OrderedDict([('iter', de_iter), ('bestidx', bestidx), ('de_iter_loss', de_iter_loss)])
                pop_key = [str(i) for i in range(args.popsize)]
                writer.add_scalars("train_metrics", dict(zip(pop_key, torch.tensor(de_iter_loss, dtype=torch.float32))), epoch*args.de_iters+de_iter)
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=True)

                parameter1=OrderedDict([('cons_sim', cons_sim), ('l2_dist', l2_dist),
                                         ('lowest_dist', lowest_dist), ('mean_dist', mean_dist), ('largest_dist', largest_dist)])
                parameter2=OrderedDict([('u_freq', u_freq), ('u_f', u_f), ('u_cr', u_cr), ('L1_value', L1_value), ('L2_value', L2_value)])

                update_summary(epoch, parameter1, parameter2, os.path.join(output_dir, 'summary2.csv'), write_header=True)

                # Append this iteration to the plotting history.
                u_freq_mean.append(np.mean(u_freq, axis=0))
                u_f_mean.append(np.mean(u_f, axis=0))
                u_cr_mean.append(np.mean(u_cr, axis=0))
                epoch_num.append([epoch*args.de_iters + de_iter for _ in range(len(np.mean(u_cr, axis=0)))])

                L1_value_list.append(L1_value.cpu().numpy())
                L2_value_list.append(L2_value.cpu().numpy())
                epoch_num_2.append([epoch*args.de_iters + de_iter for _ in range(len(L1_value))])

                cons_sim_list.append(cons_sim.item())
                l2_dist_list.append(l2_dist.item())
                lowest_dist_list.append(lowest_dist.item())
                mean_dist_list.append(mean_dist.item())
                largest_dist_list.append(largest_dist.item())

            dist.barrier()
        dist.barrier()
    return

def validate(epoch, model, loader_test, args, ckp):
    """Evaluate ``model`` on ``loader_test`` and return its PSNR and SSIM.

    Gradient tracking is switched off globally with
    ``torch.set_grad_enabled(False)`` and is not switched back on here.

    Args:
        epoch: Unused; kept for call-site compatibility.
        model: Called as ``model(lr, idx_scale)``; put into eval mode here.
        loader_test: Iterable of data loaders. Each loader yields
            ``(lr, hr, filename)`` and exposes ``dataset.set_scale``.
        args: Options namespace; reads ``save_results``, ``scale``,
            ``rgb_range``, ``distributed`` and ``world_size``.
        ckp: Checkpoint helper used for logging.

    Returns:
        ``OrderedDict`` with keys ``'psnr'`` and ``'ssim'``: the accumulated
        metric divided by ``len(loader)``. The accumulators are reset for
        every (loader, scale) pair, so the values describe the last pair
        visited.
    """
    torch.set_grad_enabled(False)
    ckp.write_log('\nEvaluation:')
    model.eval()

    timer_test = utility.timer()
    if args.save_results:
        ckp.begin_background()
    for d in loader_test:
        for idx_scale, scale in enumerate(args.scale):
            d.dataset.set_scale(idx_scale)
            psnr = 0
            ssim = 0
            for lr, hr, filename in d:
                lr, hr = prepare(lr, hr)
                sr = model(lr, idx_scale)
                sr = utility.quantize(sr, args.rgb_range)
                psnr1 = utility.calc_psnr(sr, hr, scale, args.rgb_range, dataset=d)
                ssim1 = utility.calc_ssim(sr, hr, scale, dataset=d)
                if args.distributed:
                    # Combine the metrics across ranks (SSIM first, then PSNR,
                    # so every rank issues the collectives in the same order).
                    reduced_ssim = reduce_tensor(torch.tensor(ssim1).cuda(), args.world_size)
                    reduced_psnr = reduce_tensor(torch.tensor(psnr1).cuda(), args.world_size)
                    psnr += reduced_psnr.item()
                    ssim += reduced_ssim.item()
                else:
                    # Single process: nothing to reduce, use the local values.
                    # (Previously this path hit unbound `reduced_*` names.)
                    psnr += float(psnr1)
                    ssim += float(ssim1)
            psnr /= len(d)
            ssim /= len(d)

    ckp.write_log('Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True)
    metrics = OrderedDict([('psnr', psnr), ('ssim', ssim)])
    return metrics


def score_func(model, population, loader_de, args, ckp, rand_lambda=0.5):
    """Score every member of ``population`` on a few batches of ``loader_de``.

    For each of the first ``args.de_batch_size // args.mini_batch_size``
    batches, every weight vector is loaded into ``model`` and SSIM, PSNR and
    mean absolute error against the target are accumulated. The per-member
    averages are min-max normalised (SSIM and PSNR are negated first, so a
    lower score is better) and two of them are blended:

    * ``args.mode == 'psnr'``: SSIM and PSNR
    * otherwise: SSIM and L1

    The blend is ``(1 - rand_lambda) * w_a * a + rand_lambda * w_b * b``,
    where the weight on each term is the softmax share of the *other* term.

    Gradient tracking is switched off globally and not switched back on.

    Args:
        model: Called as ``model(lr, 0)``; its weights are overwritten.
        population: Sequence of flat weight vectors.
        loader_de: Iterable yielding ``(lr, hr, filename)`` batches.
        args: Options namespace; reads ``de_batch_size``, ``mini_batch_size``,
            ``scale``, ``rgb_range``, ``distributed``, ``world_size``, ``mode``.
        ckp: Checkpoint helper used for logging.
        rand_lambda: Blend factor between the two normalised terms.

    Returns:
        List with one score per population member.
    """
    popsize = len(population)
    n_batches = args.de_batch_size // args.mini_batch_size
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    ssim_all = torch.zeros(popsize).tolist()
    psnr_all = torch.zeros(popsize).tolist()
    L1_all = torch.zeros(popsize).tolist()

    torch.set_grad_enabled(False)
    model.eval()
    end = time.time()
    for batch_idx, (lr, hr, filename) in enumerate(loader_de):
        if batch_idx >= n_batches:
            break
        data_time_m.update(time.time() - end)
        # Device / precision conversion does not depend on the member, so it
        # is done once per batch rather than once per member.
        lr, hr = prepare(lr, hr)
        for i in range(popsize):
            solution = population[i]
            model_weights_dict = model_vector_to_dict(model=model, weights_vector=solution)
            model.load_state_dict(model_weights_dict)
            if batch_idx == 0 and i == 0:
                ckp.write_log('pop: {} lr: {}'.format(i, lr.flatten()[6000:6008]))
            sr = model(lr, 0)
            # NOTE(review): quantisation uses a fixed 255 while PSNR below uses
            # args.rgb_range - confirm the two are meant to differ.
            sr = utility.quantize(sr, 255)
            ssim = utility.calc_ssim(sr, hr, args.scale[0])
            psnr = utility.calc_psnr(sr, hr, args.scale[0], args.rgb_range)
            l1 = torch.mean(torch.abs(sr-hr))
            if args.distributed:
                reduced_ssim = reduce_tensor(torch.tensor(copy.deepcopy(ssim)).cuda(), args.world_size)
                reduced_psnr = reduce_tensor(torch.tensor(copy.deepcopy(psnr)).cuda(), args.world_size)
                reduced_l1 = reduce_tensor(torch.tensor(copy.deepcopy(l1)).cuda(), args.world_size)
            else:
                # Single process: nothing to reduce, use the local values.
                # (Previously this path hit unbound `reduced_*` names.)
                reduced_ssim = torch.as_tensor(ssim)
                reduced_psnr = torch.as_tensor(psnr)
                reduced_l1 = torch.as_tensor(l1)

            ssim_all[i] += reduced_ssim.cpu().numpy()
            psnr_all[i] += reduced_psnr.cpu().numpy()
            L1_all[i] += reduced_l1.item()
        batch_time_m.update(time.time() - end)
        end = time.time()
        ckp.write_log('batch: {} '
              'data_time: {time1.val:.3f} ({time1.avg:.3f})  '
              'batch_time: {time2.val:.3f} ({time2.avg:.3f})  '.format(batch_idx, time1=data_time_m, time2=batch_time_m))
    ssim_lst = [i/n_batches for i in ssim_all]
    psnr_lst = [i/n_batches for i in psnr_all]
    l1_lst = [i/n_batches for i in L1_all]

    # Higher SSIM / PSNR is better, lower L1 is better: negate the former so
    # that "lower is better" holds for all three normalised scores.
    score_norm_ss = normal_score([-1*i for i in ssim_lst])
    score_norm_ps = normal_score([-1*i for i in psnr_lst])
    score_norm_l1 = normal_score(l1_lst)

    mean_ss = np.mean(score_norm_ss)
    mean_ps = np.mean(score_norm_ps)
    mean_l1 = np.mean(score_norm_l1)

    diff_ss = np.array(score_norm_ss) - mean_ss
    diff_ps = np.array(score_norm_ps) - mean_ps
    diff_l1 = np.array(score_norm_l1) - mean_l1

    exp_diff_ss = np.exp(diff_ss)
    exp_diff_ps = np.exp(diff_ps)
    exp_diff_l1 = np.exp(diff_l1)

    if args.mode == 'psnr':
        denominator = exp_diff_ss + exp_diff_ps
        pss_array = exp_diff_ps / denominator
        pps_array = exp_diff_ss / denominator

        answer_lst = [(1-rand_lambda) * p_ss * ss + rand_lambda * p_ps * ps for p_ss, p_ps, ss, ps in zip(pss_array, pps_array, score_norm_ss, score_norm_ps)]
    else:
        denominator = exp_diff_ss + exp_diff_l1
        pss_array = exp_diff_l1 / denominator
        pl1_array = exp_diff_ss / denominator
        # The second term is the L1 score, matching the L1-based weights
        # (the PSNR score was mixed in here before).
        answer_lst = [(1-rand_lambda) * p_ss * ss + rand_lambda * p_l1 * l for p_ss, p_l1, ss, l in zip(pss_array, pl1_array, score_norm_ss, score_norm_l1)]
    return answer_lst



def load_populaton(pop_init_dir, model, popsize):
    """Load up to ``popsize`` checkpoints from ``pop_init_dir`` as weight vectors.

    A file is treated as a checkpoint when the first or second
    underscore-separated part of its name is ``'net'``. The integer before the
    first ``.`` of the last underscore-separated part (e.g. ``377`` in
    ``..._377.pth``) is used as its index. Files are visited in
    ``os.listdir`` order until ``popsize`` checkpoints have been read; the
    result is then ordered by index.

    Args:
        pop_init_dir: Directory containing the checkpoint files.
        model: Model used to materialise each state dict; its weights are
            overwritten by every checkpoint that is loaded.
        popsize: Maximum number of checkpoints to load.

    Returns:
        List of detached flat weight vectors, sorted by checkpoint index.

    Raises:
        ValueError: If no matching checkpoint was loaded.
    """
    print(model.device)
    population = []
    batch_idx = []
    for file in os.listdir(pop_init_dir):
        if len(population) >= popsize:
            break
        # Slicing avoids an IndexError for names without an underscore.
        if 'net' not in file.split('_')[:2]:
            continue
        # Load from the directory that was listed, not from the global args.
        resume_path = os.path.join(pop_init_dir, file)
        last_name = file.split('_')[-1]  # e.g. 377.pth
        idx = int(last_name.split('.')[0])
        batch_idx.append(idx)
        print('curent idx........', idx)
        print('resume_path......', resume_path)
        print(model.device)
        print('>>>>>>>>>>>>>>>>>>>>>')
        model.load_state_dict(torch.load(resume_path, map_location='cpu'), strict=True)
        solution = model_dict_to_vector(model).detach()
        population.append(solution)

    if not population:
        raise ValueError(
            'no population checkpoints found in {!r}'.format(pop_init_dir))

    combined_sorted = sorted(zip(batch_idx, population), key=lambda pair: pair[0])
    print('current batch idx', [idx for idx, _ in combined_sorted])
    return [solution for _, solution in combined_sorted]

def load_populaton_raw(pop_init_dir, model, popsize):
    """Load up to ``popsize`` checkpoints from ``pop_init_dir``, unsorted.

    A file is treated as a checkpoint when the first or second
    underscore-separated part of its name is ``'net'``. Files are visited in
    ``os.listdir`` order and the vectors are returned in that order.

    ``psnr_log.pt`` and ``ssim_log.pt`` are also read from the directory, so
    they must exist, but their contents are currently not returned.

    Args:
        pop_init_dir: Directory containing the checkpoints and the two logs.
        model: Model used to materialise each state dict; its weights are
            overwritten by every checkpoint that is loaded.
        popsize: Maximum number of checkpoints to load.

    Returns:
        List of flat weight vectors, one per loaded checkpoint.
    """
    # All paths are built from the `pop_init_dir` argument so the directory
    # that is listed is also the one the files are loaded from.
    psnr_lst = torch.load(os.path.join(pop_init_dir, 'psnr_log.pt'))
    ssim_lst = torch.load(os.path.join(pop_init_dir, 'ssim_log.pt'))
    psnr_lst = psnr_lst.flatten().tolist()
    ssim_lst = ssim_lst.flatten().tolist()
    print(model.device)
    population = []
    for file in os.listdir(pop_init_dir):
        if len(population) >= popsize:
            break
        # Slicing avoids an IndexError for names without an underscore.
        if 'net' not in file.split('_')[:2]:
            continue
        resume_path = os.path.join(pop_init_dir, file)
        print('resume_path......', resume_path)
        torch.cuda.empty_cache()
        model.load_state_dict(torch.load(resume_path), strict=True)
        solution = model_dict_to_vector(model)
        population.append(solution)
    return population

def normal_score(score_lst):
    """Min-max normalise a sequence of scores into the range [0, 1].

    Args:
        score_lst: Sequence of numeric scores.

    Returns:
        A list holding ``(score - min) / (max - min)`` for every input score.
        An empty input yields an empty list, and an input whose values are all
        equal (including a single element) yields all zeros instead of
        dividing by zero.
    """
    scores = list(score_lst)
    if not scores:
        return []
    min_v = min(scores)
    span = max(scores) - min_v
    if span == 0:
        # Degenerate range: every score is identical, so none stands out.
        return [0.0 for _ in scores]
    return [(s - min_v) / span for s in scores]

def prepare(lr, hr):
    """Move the ``lr`` / ``hr`` tensors to the working device.

    The device is ``cpu`` when ``args.cpu`` is set; otherwise ``mps`` is tried
    first, then ``cuda``, with ``cpu`` as the final fallback. When
    ``args.precision`` is ``'half'`` each tensor is converted with ``.half()``
    before being moved.

    Returns:
        ``[lr, hr]`` after conversion, in that order.
    """
    if args.cpu:
        backend = 'cpu'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    elif torch.cuda.is_available():
        backend = 'cuda'
    else:
        backend = 'cpu'
    target = torch.device(backend)

    converted = []
    for tensor in (lr, hr):
        if args.precision == 'half':
            tensor = tensor.half()
        converted.append(tensor.to(target))
    return converted


# Script entry point: run the DE loop only when this file is executed directly.
if __name__ == '__main__':
    main()
    
    











