
import paddle
"""Dump name, shape and element-sum of every parameter in a Paddle checkpoint.

Each output line has the form ``<key> [d0,d1,...] <sum>``; 0-d (scalar)
entries are written as ``<key> <sum>`` with no bracketed shape.
"""

OUTPUT_PATH = 'largedit3b_t2i_sum_pd.txt'
CHECKPOINT_PATH = 'model_state.pdparams'

# Keys containing any of these substrings belong to sub-models that are not
# dumped (the VAE and the EMA copy of the weights).
SKIP_SUBMODULES = ('vae.', 'model_ema.')

# Keys containing any of these substrings are bookkeeping buffers rather than
# weights. 'running_mean' / 'running_var' are intentionally NOT excluded.
not_keys = ['num_batches_tracked']


def is_excluded(key):
    """Return True if the parameter named *key* should be left out of the dump."""
    return (any(sub in key for sub in SKIP_SUBMODULES)
            or any(nk in key for nk in not_keys))


def format_entry(key, shape, total):
    """Return one dump line for a parameter.

    Args:
        key: parameter name.
        shape: sequence of ints (any rank, possibly empty for a scalar).
        total: the parameter's element sum, written with ``str`` formatting.

    Returns:
        ``'<key> [d0,...,dn] <total>\\n'``, or ``'<key> <total>\\n'`` when
        *shape* is empty.
    """
    dims = list(shape)
    if not dims:
        return '{} {}\n'.format(key, total)
    # Join every dimension so tensors of any rank (including 3-D and 5-D+)
    # get their full shape, not just the leading dimension.
    return '{} [{}] {}\n'.format(key, ','.join(str(d) for d in dims), total)


def main(checkpoint_path=CHECKPOINT_PATH, output_path=OUTPUT_PATH):
    """Load *checkpoint_path* and write one summary line per kept parameter."""
    # Load before opening the output, so a failed load does not leave an
    # empty/truncated summary file behind.
    model_dict = paddle.load(checkpoint_path)
    with open(output_path, 'w', encoding='utf-8') as out:
        for key, tensor in model_dict.items():
            if is_excluded(key):
                continue
            shape = tensor.shape
            # Progress notices kept from the original script.
            if len(shape) == 4:
                print(key, '4')
            elif len(shape) == 0:
                print(key)
            out.write(format_entry(key, shape, tensor.sum().item()))


if __name__ == '__main__':
    main()
